In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import PROJECT.util as f
from matplotlib.pyplot import cm
from mpl_toolkits import mplot3d
import plotly.graph_objects as go
import pandas as pd

In [None]:
def mse1(u_e, v_e, u_t, v_t):
    """Sum of the per-vector mean squared errors along dim 0.

    Returns ``mean((u_e - u_t)**2) + mean((v_e - v_t)**2)`` where each mean is
    taken over dimension 0 (a scalar for 1-D inputs).
    """
    err_u = ((u_e - u_t) ** 2).mean(0)
    err_v = ((v_e - v_t) ** 2).mean(0)
    return err_u + err_v
def mse2(u_e, v_e, u_t, v_t):
    """Mean squared difference between the rank-one products u_e v_e^T and u_t v_t^T.

    ``torch.tensordot(a, b, dims=0)`` is the outer product for 1-D inputs.
    """
    outer_e = torch.tensordot(u_e, v_e, dims=0)
    outer_t = torch.tensordot(u_t, v_t, dims=0)
    gap = outer_e - outer_t
    return (gap ** 2).mean()
def mse3(u_e, v_e, u_t, v_t, Y, N, M, lambda_):
    """Mean squared residual of Y against the scaled rank-one product of (u_t, v_t).

    Computes ``mean((Y - sqrt(lambda_/N) * u_t v_t^T)**2)``. ``u_e``, ``v_e`` and
    ``M`` are accepted for signature compatibility with the other mse helpers but
    are not used. ``lambda_ / N`` must be a tensor (``torch.sqrt`` is applied to it).
    """
    scale = torch.sqrt(lambda_ / N)
    residual = Y - (scale * torch.tensordot(u_t, v_t, dims=0))
    return (residual ** 2).mean()

In [None]:
def main_gradient_2_sans_proj_normalisation(iteration, u_p, v_p, Y, u_, v_, N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt):
    """Run the noisy gradient dynamics on (u, v), rescaling the iterates after every step.

    The name is French ("sans" = "without"): instead of a projection, each new
    iterate is renormalised. At each of the ``iteration`` steps, starting from
    the previous iterate ``(u_p, v_p)``::

        u = u_p - (1/lambda_1) * grad_u * dt
                + sqrt(2/(lambda_1*beta_u)) * xi_u
                - ((N-1)/(N*lambda_1*beta_u)) * u_p * dt
        u = sqrt(N) * u / ||u||

    with ``grad_u = f.gradient_u_2(N, M, u_p, v_p, Y, lambda_)`` and ``xi_u``
    having ``N`` i.i.d. Normal(0, dt) entries (std ``sqrt(dt)``). ``v`` is updated
    the same way with ``lambda_2``, ``beta_v``, ``M`` and ``f.gradient_v_2``.
    Both gradients are evaluated at the previous iterate. With ``beta = inf`` the
    noise and the last drift term vanish (``2/inf == 0``).

    Args:
        iteration: number of steps.
        u_p, v_p: initial iterates (the noise added to them has N and M entries).
        Y: observed matrix, passed to the gradient helpers and to ``mse3``.
        u_, v_: reference vectors the iterates are compared against.
        N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt: scalar
            hyper-parameters (0-d tensors when called from ``one_run``).

    Returns:
        ``(res, res_mse)`` numpy arrays with one row per step:
        ``res[i] = (|overlap(u_, u)|, |overlap(v_, v)|)`` and
        ``res_mse[i] = (mse1 on absolute values, mse2, mse3)``.
    """
    # Loop-invariant quantities, computed once instead of at every step.
    # torch.as_tensor returns tensors unchanged and lets torch.sqrt accept plain numbers.
    sqrt_dt = torch.sqrt(torch.as_tensor(dt))
    noise_scale_u = torch.sqrt(torch.as_tensor(2/(lambda_1*beta_u)))
    noise_scale_v = torch.sqrt(torch.as_tensor(2/(lambda_2*beta_v)))
    drift_coef_u = (N-1)/(N*lambda_1*beta_u)
    drift_coef_v = (M-1)/(M*lambda_2*beta_v)
    target_norm_u = torch.sqrt(torch.as_tensor(N))
    target_norm_v = torch.sqrt(torch.as_tensor(M))

    res = []
    res_mse = []

    for _ in range(iteration):
        # Update of u: gradient term, noise term, linear drift term.
        u_1 = (1/lambda_1) * f.gradient_u_2(N, M, u_p, v_p, Y, lambda_) * dt
        u_2 = noise_scale_u * torch.empty(N).normal_(mean=0, std=sqrt_dt)
        u_3 = drift_coef_u * u_p * dt
        u_n = u_p - u_1 + u_2 - u_3

        # Update of v, also from the previous (u_p, v_p).
        v_1 = (1/lambda_2) * f.gradient_v_2(N, M, u_p, v_p, Y, lambda_) * dt
        v_2 = noise_scale_v * torch.empty(M).normal_(mean=0, std=sqrt_dt)
        v_3 = drift_coef_v * v_p * dt
        v_n = v_p - v_1 + v_2 - v_3

        # Rescale so that ||u|| = sqrt(N) and ||v|| = sqrt(M).
        u_n = u_n / torch.linalg.norm(u_n) * target_norm_u
        v_n = v_n / torch.linalg.norm(v_n) * target_norm_v

        # The new iterate becomes the starting point of the next step.
        u_p = u_n
        v_p = v_n

        res_u = torch.abs(f.overlap(u_, u_n, N))
        res_v = torch.abs(f.overlap(v_, v_n, M))
        res.append((res_u, res_v))

        # mse1 compares entry-wise absolute values; mse2/mse3 use the raw vectors.
        res_mse.append((mse1(torch.abs(u_), torch.abs(v_), torch.abs(u_n), torch.abs(v_n)),
                        mse2(u_, v_, u_n, v_n),
                        mse3(u_, v_, u_n, v_n, Y, N, M, lambda_)))

    return np.array(res), np.array(res_mse)

In [None]:
def one_run(lambda_=2, N=200, M=200, beta_u=float("inf"), beta_v=float("inf"), lambda_1=1,
        lambda_2=1, dt=1/100, iteration=5000):
    """Generate one problem instance and run the dynamics on it.

    Builds reference vectors ``u_`` (size N) and ``v_`` (size M), the matrix
    ``Y = f.generate_Y(N, M, u_, v_, lambda_)`` and a random starting point, then
    runs ``main_gradient_2_sans_proj_normalisation`` for ``iteration`` steps.

    Args:
        lambda_: parameter forwarded to ``f.generate_Y`` and the gradients.
        N, M: sizes of Y.
        beta_u, beta_v: inverse temperatures (inf disables the noise).
        lambda_1, lambda_2: learning rates for u and v.
        dt: time step.
        iteration: number of steps.

    Returns:
        ``(res, res_mse)`` as returned by ``main_gradient_2_sans_proj_normalisation``.
    """
    # Every scalar hyper-parameter becomes a 0-d tensor (torch.sqrt is applied to them later).
    N, M = torch.tensor(N), torch.tensor(M)
    lambda_ = torch.tensor(lambda_)
    beta_u, beta_v = torch.tensor(beta_u), torch.tensor(beta_v)
    lambda_1, lambda_2 = torch.tensor(lambda_1), torch.tensor(lambda_2)
    dt = torch.tensor(dt)

    # Reference vectors and the observation matrix built from them
    # (generation order kept: u_, v_, Y, then the initial condition).
    u_ref = f.generate_vector(N)
    v_ref = f.generate_vector(M)
    Y = f.generate_Y(N, M, u_ref, v_ref, lambda_)

    # Initial condition of the dynamics.
    u_start = f.generate_vector(N)
    v_start = f.generate_vector(M)

    return main_gradient_2_sans_proj_normalisation(
        iteration, u_start, v_start, Y, u_ref, v_ref, N, M, lambda_,
        beta_u, beta_v, lambda_1, lambda_2, dt)

In [None]:
# MAIN METHOD:

def main(lambda_u_val, lambda_v_val, size_of_one_sample, list_value_1, list_value_2, **run_kwargs):
    """Sweep (beta_u, beta_v) over a grid, running ``one_run`` several times per grid point.

    For each ``value_to_test_1`` in ``list_value_1`` (outer loop) and each
    ``value_to_test_2`` in ``list_value_2`` (inner loop), ``one_run`` is called
    ``size_of_one_sample`` times with ``beta_u=value_to_test_1``,
    ``beta_v=value_to_test_2``, ``lambda_1=lambda_u_val`` and ``lambda_2=lambda_v_val``.

    Args:
        lambda_u_val, lambda_v_val: learning rates forwarded as ``lambda_1`` / ``lambda_2``.
        size_of_one_sample: number of independent runs per grid point.
        list_value_1, list_value_2: values tried for ``beta_u`` and ``beta_v``.
        **run_kwargs: extra keyword arguments forwarded to every ``one_run`` call
            (e.g. ``N``, ``M``, ``lambda_``, ``dt``, ``iteration``); omitted ones keep
            ``one_run``'s defaults. Must not repeat beta_u, beta_v, lambda_1 or lambda_2.

    Returns:
        Four numpy arrays whose first axis enumerates the grid points row-major
        (index ``i * len(list_value_2) + j`` for ``list_value_1[i]``, ``list_value_2[j]``)
        and whose second axis enumerates the runs:
        overlaps at every iteration, overlaps at the final iteration,
        mse triples at every iteration, mse triples at the final iteration.
    """
    all_overlap_iterations = []
    all_final_overlaps = []

    all_mse_iterations = []
    all_final_mse = []

    num_iter = 1
    tt_iter = len(list_value_1) * len(list_value_2) * size_of_one_sample

    for value_to_test_1 in list_value_1:

        for value_to_test_2 in list_value_2:

            one_run_overlap_iterations = []
            one_run_mse_iterations = []

            for _ in range(size_of_one_sample):
                current_uv_overlap, current_uv_mse = one_run(
                    lambda_1=lambda_u_val, lambda_2=lambda_v_val,
                    beta_u=value_to_test_1, beta_v=value_to_test_2,
                    **run_kwargs)
                # Full trajectories (one row per iteration) of this run.
                one_run_overlap_iterations.append(current_uv_overlap)
                one_run_mse_iterations.append(current_uv_mse)

                print(f"progress {100*num_iter/tt_iter:.2f}%", end="\r")
                num_iter += 1

            # Keep only the last iteration of each run.
            one_run_final_overlap = np.array([x[-1] for x in one_run_overlap_iterations])
            one_run_final_mse = np.array([x[-1] for x in one_run_mse_iterations])

            all_overlap_iterations.append(one_run_overlap_iterations)
            all_final_overlaps.append(one_run_final_overlap)
            all_mse_iterations.append(one_run_mse_iterations)
            all_final_mse.append(one_run_final_mse)

    return np.array(all_overlap_iterations), np.array(all_final_overlaps), np.array(all_mse_iterations), np.array(all_final_mse)

In [None]:
# SELECT THE RANGES:
# Names of the two swept parameters (used in axis and plot titles).
name_value_1 = "beta_u"
name_value_2 = "beta_v"

# Values tried for beta_u (outer loop of main). Other ranges that were used:
#   np.logspace(-1, 0.5, 30)    or    np.arange(0.1, 3, 0.1)
list_value_1 = np.array([0.5, 1.0, 2.0])

# Values tried for beta_v (inner loop of main); same alternative ranges as above.
list_value_2 = np.array([0.5, 1.0, 2.0])

In [None]:
# FIRST RUN:

# Size of Y, used only as labels in legends and file names below.
# This cell does not pass N or M to main(), so one_run() runs with its
# defaults N=M=200; these labels must match the sizes actually simulated.
name_n_1="200"
name_m_1="200"

# Learning rates lambda_1 (for u) and lambda_2 (for v)
lambda_u_val_1 = 1
lambda_v_val_1 = 1

# Number of independent runs per (beta_u, beta_v) grid point
size_of_one_sample=10

(overlap_at_each_iterations, overlap_at_final_iteration,
 mse_at_each_iterations, mse_at_final_iteration) = main(lambda_u_val_1, lambda_v_val_1,
                                            size_of_one_sample, list_value_1, list_value_2)

In [None]:
# SECOND RUN

# Size of Y, used only as labels in legends and file names below.
# This cell does not pass N or M to main(), so one_run() runs with its
# defaults N=M=200; these labels must match the sizes actually simulated.
name_n_2="200"
name_m_2="200"

# Learning rates lambda_1 (for u) and lambda_2 (for v)
lambda_u_val_2 = 1
lambda_v_val_2 = 0.1

# Number of independent runs per (beta_u, beta_v) grid point
size_of_one_sample=10

(overlap_at_each_iterations_2, overlap_at_final_iteration_2,
 mse_at_each_iterations_2, mse_at_final_iteration_2) = main(lambda_u_val_2, lambda_v_val_2,
                                                        size_of_one_sample, list_value_1, list_value_2)

In [None]:
# To save the plotly figures as HTML files, set the destination directory
# below, including the trailing slash, e.g. FILE_PATH="C:/Users/Admin/Desktop/".
# Then uncomment the commented-out `fig.write_html(...)` line at the end of
# each plotly cell. With the placeholder "..." the file names would simply
# start with "..." in the current working directory.
FILE_PATH="..."
# And you need to uncomment the last line of the cell that are containing plotly graph:
# This line: fig.write_html("... .html")

In [None]:
# REMINDER:
# Double-click an item in the legend to display only that item;
# a single click just hides it.

# 3D PLOT OVERLAP

# Final-iteration overlap averaged over the runs and over the (u, v) pair, then
# reshaped to (len(list_value_1), len(list_value_2)): rows follow list_value_1
# (outer loop of main), columns follow list_value_2 (inner loop).
grid_shape = (len(list_value_1), len(list_value_2))

overlap_mean = overlap_at_final_iteration.mean(axis=1).mean(axis=1)
overlap_mean = overlap_mean.reshape(grid_shape)

overlap_mean_2 = overlap_at_final_iteration_2.mean(axis=1).mean(axis=1)
overlap_mean_2 = overlap_mean_2.reshape(grid_shape)

# go.Surface maps the columns of z to x and its rows to y, so x must be
# list_value_2 and y must be list_value_1 (matching the axis titles below).
fig = go.Figure(data=[
    go.Surface(name=f"n={name_n_1}, m={name_m_1}; lu={lambda_u_val_1}, lv={lambda_v_val_1}",
               z=overlap_mean, x=list_value_2, y=list_value_1,
               colorscale='Blues', showscale=False, opacity=1),
    go.Surface(name=f"n={name_n_2}, m={name_m_2}; lu={lambda_u_val_2}, lv={lambda_v_val_2}",
               z=overlap_mean_2, x=list_value_2, y=list_value_1,
               colorscale='Greens', showscale=False, opacity=0.9)])

fig.update_layout(title="OVERLAP:")

fig.update_layout(showlegend=True, scene=dict(
                    xaxis_title=name_value_2,
                    yaxis_title=name_value_1,
                    zaxis_title="overlap"))
fig.update_traces(showlegend=True)

fig.show()
path=f"{FILE_PATH}R1_n={name_n_1}_m={name_m_1}_"\
f"Ldiff={lambda_u_val_1/lambda_v_val_1}_"\
f"R2_n={name_n_2}_m={name_m_2}_"\
f"Ldiff={lambda_u_val_2/lambda_v_val_2}_overlap.html"

#fig.write_html(path)

In [None]:
# 3D PLOT MSE

# Final-iteration (mse1, mse2, mse3) averaged over the runs of each grid point,
# each reshaped to (len(list_value_1), len(list_value_2)): rows follow
# list_value_1 (outer loop of main), columns follow list_value_2 (inner loop).
grid_shape = (len(list_value_1), len(list_value_2))
mse_mean_1, mse_mean_2, mse_mean_3 = (m.reshape(grid_shape) for m in mse_at_final_iteration.mean(axis=1).T)
mse_mean_4, mse_mean_5, mse_mean_6 = (m.reshape(grid_shape) for m in mse_at_final_iteration_2.mean(axis=1).T)

label_run_1 = f"n={name_n_1}, m={name_m_1}; lu={lambda_u_val_1}, lv={lambda_v_val_1}"
label_run_2 = f"n={name_n_2}, m={name_m_2}; lu={lambda_u_val_2}, lv={lambda_v_val_2}"

# go.Surface maps the columns of z to x and its rows to y, so x must be
# list_value_2 and y must be list_value_1 (matching the axis titles below).
# Trace order: mse1 run 1, mse1 run 2, mse2 run 1, mse2 run 2, mse3 run 1, mse3 run 2.
surfaces = []
for k, (grid_run_1, grid_run_2) in enumerate(
        [(mse_mean_1, mse_mean_4), (mse_mean_2, mse_mean_5), (mse_mean_3, mse_mean_6)], start=1):
    surfaces.append(go.Surface(name=f"mse{k} {label_run_1}",
                               z=grid_run_1, x=list_value_2, y=list_value_1,
                               colorscale='solar', showscale=False, opacity=1))
    surfaces.append(go.Surface(name=f"mse{k} {label_run_2}",
                               z=grid_run_2, x=list_value_2, y=list_value_1,
                               colorscale='ice', showscale=False, opacity=0.9))

fig = go.Figure(data=surfaces)

fig.update_layout(title="MSE:")
fig.update_layout(scene=dict(
                    xaxis_title=name_value_2,
                    yaxis_title=name_value_1,
                    zaxis_title="mse"))
fig.update_traces(showlegend=True)

fig.show()
path=f"{FILE_PATH}R1_n={name_n_1}_m={name_m_1}_"\
f"Ldiff={lambda_u_val_1/lambda_v_val_1}_"\
f"R2_n={name_n_2}_m={name_m_2}_"\
f"Ldiff={lambda_u_val_2/lambda_v_val_2}_mse.html"

#fig.write_html(path)

In [None]:
# Needed for the side-by-side 3D subplots below.
from plotly.subplots import make_subplots
# This graph displays the overlap of u (left) and v (right) separately
# for the SECOND run.

# The first config always used lambda_1 = lambda_2 = 1, so the second config
# (different lambda_1 and lambda_2) is the more interesting one to inspect.

# To plot the first config instead, replace:
# overlap_at_final_iteration_2 with overlap_at_final_iteration
# name_n_2 with name_n_1
# name_m_2 with name_m_1
# lambda_u_val_2 with lambda_u_val_1
# lambda_v_val_2 with lambda_v_val_1

# Mean over the runs of the final-iteration overlaps, split into u and v, each
# reshaped to (len(list_value_1), len(list_value_2)).
grid_shape = (len(list_value_1), len(list_value_2))
array_u_overlap, array_v_overlap = overlap_at_final_iteration_2.mean(axis=1).T

array_u_overlap = array_u_overlap.reshape(grid_shape)
array_v_overlap = array_v_overlap.reshape(grid_shape)

fig = make_subplots(rows=1, cols=2,
                    specs=[[{'type': 'surface'}, {'type': 'surface'}]],
                    subplot_titles=("overlap_u", "overlap_v"))
fig.add_trace(go.Surface(z=array_u_overlap, x=list_value_2, y=list_value_1,
                         name="overlap_u", showscale=False),
              row=1, col=1)
# The v overlap goes in the second scene so that both are shown side by side.
fig.add_trace(go.Surface(z=array_v_overlap, x=list_value_2, y=list_value_1,
                         name="overlap_v", showscale=False, opacity=0.9),
              row=1, col=2)

fig.update_layout(
    title=f"n={name_n_2}, m={name_m_2}; lu={lambda_u_val_2}, lv={lambda_v_val_2}",
)
fig.update_layout(scene=dict(
                    xaxis_title=name_value_2,
                    yaxis_title=name_value_1,
                    zaxis_title="overlap"),
                  scene2=dict(
                    xaxis_title=name_value_2,
                    yaxis_title=name_value_1,
                    zaxis_title="overlap"),
                  autosize=False, width=1000, height=500)
fig.update_traces(showlegend=True)

fig.show()
path=f"{FILE_PATH}n={name_n_2}_m={name_m_2}_lu={lambda_u_val_2}_lv={lambda_v_val_2}_uv_overlap.html"
#fig.write_html(path)

In [None]:
# Stabilisation of the overlap for the SECOND RUN.
# selected_run is the index of a (beta_u, beta_v) grid point in main()'s order:
# selected_run = i * len(list_value_2) + j for list_value_1[i], list_value_2[j].
# Change it to inspect another configuration.
selected_run=0

# All runs of the selected grid point: one (iteration, 2) trajectory per run.
runs_overlap = overlap_at_each_iterations_2[selected_run]

# Mean over the runs, with the standard deviation as error bars
# (same units as the overlap, unlike the variance).
overlap_u, overlap_v = runs_overlap.mean(axis=0).T
std_u, std_v = runs_overlap.std(axis=0).T

x = np.arange(len(overlap_u))

fig=go.Figure(data=[
    go.Scatter(
        x=x,
        y=overlap_u,
        error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=std_u,
            visible=True),
        name="u overlap"
    ),
    go.Scatter(
        x=x,
        y=overlap_v,
        error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=std_v,
            visible=True),
        name="v overlap"
    )])

fig.update_layout(
    title="Stabilisation overlap pour: "+f"{name_value_1}={list_value_1[selected_run//len(list_value_2)]}"
                    + f" & {name_value_2}={list_value_2[selected_run%len(list_value_2)]}",
    xaxis_title="Iterations",
    yaxis_title="Overlap",
)

fig.show()
path=f"{FILE_PATH}n={name_n_2}_m={name_m_2}_lu={lambda_u_val_2}_lv={lambda_v_val_2}_run={selected_run}.html"
#fig.write_html(path)

In [None]:
# Stabilisation of the mse and overlap for the SECOND RUN.

# REMINDER:
# Double-click an item in the legend to display only that item;
# a single click just hides it.

# Index of a (beta_u, beta_v) grid point in main()'s order:
# selected_run = i * len(list_value_2) + j for list_value_1[i], list_value_2[j].
selected_run=0

# All runs of the selected grid point: overlap trajectories are (iteration, 2)
# and mse trajectories (iteration, 3) per run.
runs_overlap = overlap_at_each_iterations_2[selected_run]
runs_mse = mse_at_each_iterations_2[selected_run]

# Overlap averaged over u and v, then mean / std across the runs.
overlap_uv = runs_overlap.mean(axis=2)
overlap_uv_mean = overlap_uv.mean(axis=0)
overlap_uv_std = overlap_uv.std(axis=0)

# Distinct names so the module-level mse1/mse2/mse3 functions are not shadowed
# (shadowing them breaks any later re-run of main()).
mse1_mean, mse2_mean, mse3_mean = runs_mse.mean(axis=0).T
mse1_std, mse2_std, mse3_std = runs_mse.std(axis=0).T

iterations = np.arange(len(mse1_mean))

# Error bars are standard deviations across runs (same units as the curves).
series = [("mse1", mse1_mean, mse1_std),
          ("mse2", mse2_mean, mse2_std),
          ("mse3", mse3_mean, mse3_std),
          ("overlap", overlap_uv_mean, overlap_uv_std)]

fig = go.Figure(data=[
    go.Scatter(
        x=iterations,
        y=values,
        name=label,
        error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=errors,
            visible=True))
    for label, values, errors in series])

fig.update_layout(
    title=f"{name_value_1}={list_value_1[selected_run//len(list_value_2)]}"
            + f" & {name_value_2}={list_value_2[selected_run%len(list_value_2)]}",
    xaxis_title="Iterations",
)

fig.show()
#fig.write_html(f"{FILE_PATH}stabilisation_mse_overlap.html")