In [8]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from matplotlib.pyplot import cm
from mpl_toolkits import mplot3d
from scipy.signal import savgol_filter

import PROJECT.util as f

In [9]:
# For ground-truth vectors u* and v*, the overlaps that are obtained are
# either both positive or both negative, i.e. we have either
# overlap(u*) = -x and overlap(v*) = -y, or overlap(u*) = x and overlap(v*) = y.
# We therefore take min(u-x, u+x):

# if u is positive and x is positive:
# min(u-x, u+x) -> u-x is the right one

# if u is positive and x is negative:
# min(u-(-x), u+(-x)) -> u+(-x) = u-x is the right one

# if u is negative and x is positive:
# min(-u-x, -u+x) -> -u+x = x-u is the right one

# if u is negative and x is negative:
# min(-u-(-x), -u+(-x)) -> -u-(-x) = x-u is the right one

def mse1(u,u_n):
    """Sign-invariant mean squared error between ``u`` and ``u_n``.

    The squared error is averaged along dimension 0 twice: once for
    ``u - u_n`` and once for ``u + u_n`` (i.e. against ``-u_n``). The
    element-wise smaller of the two averages is returned, so an estimate
    that differs from the reference only by a global sign scores zero.
    """
    error_same_sign = ((u - u_n) ** 2).mean(dim=0)
    error_flipped_sign = ((u + u_n) ** 2).mean(dim=0)
    return torch.min(error_same_sign, error_flipped_sign)
def mse2(u_e, v_e, u_t, v_t):
    """Summed squared difference between two outer products.

    Builds ``u_e (x) v_e`` and ``u_t (x) v_t`` with ``torch.tensordot(..., 0)``
    and returns the sum (not the mean, despite the name) of the squared
    entries of their difference.
    """
    product_e = torch.tensordot(u_e, v_e, dims=0)
    product_t = torch.tensordot(u_t, v_t, dims=0)
    return (product_e - product_t).pow(2).sum()
def mse3(u_e, v_e, u_t, v_t, Y):
    """Summed squared residual between ``Y`` and the outer product of ``u_t`` and ``v_t``.

    ``u_e`` and ``v_e`` are accepted but not used by the computation.
    The result is a sum (not a mean) of squared entries.
    """
    residual = Y - torch.tensordot(u_t, v_t, dims=0)
    return residual.pow(2).sum()

In [10]:
def main_gradient_2_sans_proj_normalisation(iteration, u_p, v_p, Y, u_, v_, N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt,
                                            stabilisation_window=200, stabilisation_tol=0.0001):
    """Iterate a noisy, norm-constrained gradient update on ``(u, v)``.

    At every iteration the current estimates are updated by
      * subtracting a gradient step (``f.gradient_u_2`` / ``f.gradient_v_2``
        scaled by ``dt / lambda_1`` resp. ``dt / lambda_2``),
      * adding Gaussian noise of standard deviation ``sqrt(dt)`` scaled by
        ``sqrt(2 / (lambda * beta))``,
      * subtracting a term proportional to the current vector,
    and are then rescaled so that ``u`` has norm ``sqrt(N)`` and ``v`` has
    norm ``sqrt(M)``.

    After each update the absolute overlaps with the reference vectors and
    the value of ``mse2`` are recorded.

    Early stopping: the overlap pairs are collected in a buffer. A new pair
    is appended only if it lies within ``stabilisation_tol`` of every pair
    already in the buffer, otherwise the buffer is emptied. Once the buffer
    holds ``stabilisation_window`` pairs, the function returns at the next
    iteration.

    Args:
        iteration: maximum number of iterations.
        u_p, v_p: initial estimates (must be compatible with noise vectors
            of length ``N`` and ``M`` respectively).
        Y: observation passed to the gradient helpers.
        u_, v_: reference vectors used for the overlap and ``mse2``.
        N, M: vector sizes, given as tensors (they are passed to ``torch.sqrt``).
        lambda_: parameter forwarded to the gradient helpers.
        beta_u, beta_v: tensors controlling the noise and linear-term amplitude.
        lambda_1, lambda_2: tensors scaling the update of ``u`` and ``v``.
        dt: time step, given as a tensor.
        stabilisation_window: number of consecutive stable overlap pairs
            required before stopping early (default 200).
        stabilisation_tol: absolute tolerance used to decide that a pair is
            stable (default 0.0001).

    Returns:
        A pair ``(overlaps, mse)`` of NumPy arrays with one entry per executed
        iteration: ``overlaps`` holds the ``(|overlap_u|, |overlap_v|)`` pairs
        and ``mse`` the ``mse2`` values. Fewer than ``iteration`` entries are
        returned when the run stops early.
    """
    res = []
    res_mse = []

    # Integer stride (about 10% of the run) so that `i % progress_stride` is
    # exact; a float stride only hits 0 reliably when `iteration` is a
    # multiple of 10.
    progress_stride = max(1, iteration // 10)

    stabilisation_check = np.empty((0, 2))

    # Loop-invariant quantities, computed once instead of at every iteration.
    sqrt_dt = torch.sqrt(dt)
    noise_scale_u = torch.sqrt(2/(lambda_1*beta_u))
    noise_scale_v = torch.sqrt(2/(lambda_2*beta_v))
    linear_coef_u = (N-1)/(N*lambda_1*beta_u)
    linear_coef_v = (M-1)/(M*lambda_2*beta_v)
    target_norm_u = torch.sqrt(N)
    target_norm_v = torch.sqrt(M)

    for i in range(iteration):

        # Computation (the u noise is drawn before the v noise, as before)
        u_1 = (1/lambda_1) * f.gradient_u_2(N, M, u_p,v_p,Y,lambda_) * dt
        u_2 = noise_scale_u * torch.empty(N).normal_(mean=0,std=sqrt_dt)
        u_3 = linear_coef_u*u_p*dt
        u_n = u_p - u_1 + u_2 - u_3

        v_1 = 1/lambda_2 * f.gradient_v_2(N,M,u_p,v_p,Y,lambda_) * dt
        v_2 = noise_scale_v * torch.empty(M).normal_(mean=0,std=sqrt_dt)
        v_3 = linear_coef_v*v_p*dt
        v_n = v_p - v_1 + v_2 - v_3

        # Normalisation: rescale to norm sqrt(N) and sqrt(M)
        u_n = u_n / torch.linalg.norm(u_n)
        v_n = v_n / torch.linalg.norm(v_n)
        u_n = u_n * target_norm_u
        v_n = v_n * target_norm_v

        # Re-assign for the next iteration
        u_p = u_n
        v_p = v_n

        res_u = torch.abs(f.overlap(u_,u_n,N))
        res_v = torch.abs(f.overlap(v_,v_n,M))
        res.append((res_u,res_v))

        res_mse.append(mse2(u_,v_,u_n,v_n))

        if len(stabilisation_check)==0:
            # Start a new stability window with the current pair.
            stabilisation_check=np.append(stabilisation_check, [[res_u,res_v]], axis=0)
        elif len(stabilisation_check)==stabilisation_window:
            # Window full: the overlaps are considered stable, stop here.
            print(f"overlap(u, v): ({res_u}; {res_v})")
            return np.array(res), np.array(res_mse)
        elif (np.abs(stabilisation_check-[res_u,res_v])<stabilisation_tol).all():
            # Current pair agrees with every pair of the window: extend it.
            stabilisation_check=np.append(stabilisation_check, [[res_u,res_v]], axis=0)
        else:
            # Current pair deviates: discard the window.
            stabilisation_check=np.empty((0,2))

        if i%progress_stride==0:
            print(f"progress {100*i/iteration:.2f}%", end="\r")
        if i==iteration-1:
            print(f"overlap(u, v): ({res_u}; {res_v})")

    return np.array(res), np.array(res_mse)

In [11]:
def one_run(lambda_1,lambda_2, N, M,lambda_=2, beta_u=float("inf"), 
            beta_v=float("inf"),  dt=1/100, iteration=1000):
    """Generate one problem instance and run the gradient procedure on it.

    All scalar arguments are first wrapped into tensors. Reference vectors
    ``u_`` (size ``N``) and ``v_`` (size ``M``) and the matrix ``Y`` are
    produced by the ``f.generate_*`` helpers, two further vectors are
    generated as initial conditions, and everything is handed to
    ``main_gradient_2_sans_proj_normalisation``.

    Returns:
        The ``(overlaps, mse)`` pair returned by the gradient procedure.
    """
    # Wrap every scalar (sizes, lambda, betas, learning rates, time step)
    # into a tensor.
    N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt = (
        torch.tensor(scalar)
        for scalar in (N, M, lambda_, beta_u, beta_v, lambda_1, lambda_2, dt)
    )

    # Reference vectors and the matrix Y built from them.
    u_ = f.generate_vector(N)
    v_ = f.generate_vector(M)
    Y = f.generate_Y(N, M, u_, v_, lambda_)

    # Initial conditions (generated after the reference vectors and Y).
    u_start = f.generate_vector(N)
    v_start = f.generate_vector(M)

    # Perform the gradient descent.
    overlaps, mse_values = main_gradient_2_sans_proj_normalisation(
        iteration, u_start, v_start, Y, u_, v_, N, M,
        lambda_, beta_u, beta_v, lambda_1, lambda_2, dt)

    return overlaps, mse_values



In [12]:
# MAIN METHOD:

def main(size_of_one_sample, list_value_1, list_value_2,lambda_u,lambda_v,d_u,d_v):
    """Scan a grid of ``(beta_u, beta_v)`` values, repeating ``one_run`` for each pair.

    For every value of ``list_value_1`` (used as ``beta_u``) and every value
    of ``list_value_2`` (used as ``beta_v``), ``one_run`` is called
    ``size_of_one_sample`` times with ``lambda_1=lambda_u``,
    ``lambda_2=lambda_v``, ``N=d_u`` and ``M=d_v``.

    Returns:
        Four NumPy arrays whose first axis enumerates the grid pairs with
        ``list_value_2`` varying fastest:
          1. the overlap history of every run,
          2. the overlap at the last iteration of every run,
          3. the mse history of every run,
          4. the mse at the last iteration of every run.

    NOTE(review): runs that stop early have shorter histories; histories of
    different lengths cannot be stacked into a regular array.
    """
    overlap_histories = []
    last_overlaps = []
    mse_histories = []
    last_mses = []

    for beta_u_value in list_value_1:
        for beta_v_value in list_value_2:
            # Repeat the experiment for this pair of values.
            runs = [
                one_run(lambda_1=lambda_u, lambda_2=lambda_v, N=d_u, M=d_v,
                        beta_u=beta_u_value, beta_v=beta_v_value)
                for _ in range(size_of_one_sample)
            ]
            overlaps_of_pair = [overlap_history for overlap_history, _ in runs]
            mses_of_pair = [mse_history for _, mse_history in runs]

            # Histories over all iterations.
            overlap_histories.append(overlaps_of_pair)
            mse_histories.append(mses_of_pair)

            # Values at the last recorded iteration of each run.
            last_overlaps.append(np.array([history[-1] for history in overlaps_of_pair]))
            last_mses.append(np.array([history[-1] for history in mses_of_pair]))

    return (np.array(overlap_histories), np.array(last_overlaps),
            np.array(mse_histories), np.array(last_mses))

In [13]:
# MAIN CELL: DO THE COMPUTATION

# Values scanned for beta_u (list_value_1) and beta_v (list_value_2).
list_value = np.logspace(-1,0.5,num=10)
list_value_1 = list_value
list_value_2 = list_value

lambda_u = 1
lambda_v = 1

d_u = 500
d_v = 500

# The call below and the plotting cells refer to the "_1" names of this
# configuration; define them here so the notebook runs on a fresh kernel.
# NOTE(review): assumed to be the values defined just above - confirm.
lambda_u_1 = lambda_u
lambda_v_1 = lambda_v
d_u_1 = d_u
d_v_1 = d_v

size_of_one_sample=10

(overlap_at_each_iterations, overlap_at_final_iteration,
 mse_at_each_iterations, mse_at_final_iteration) = main(size_of_one_sample, list_value_1, list_value_2,lambda_u_1,lambda_v_1,d_u_1,d_v_1)

overlap(u, v): (0.06267625093460083; 0.008093798533082008)
overlap(u, v): (0.05526364967226982; 0.027307678014039993)
overlap(u, v): (0.004527139477431774; 0.05152348428964615)
overlap(u, v): (0.08182481676340103; 0.0015619087498635054)
overlap(u, v): (0.02196245640516281; 0.003027210244908929)
overlap(u, v): (0.12431563436985016; 0.09073515981435776)
overlap(u, v): (0.06996233761310577; 0.01603589951992035)
overlap(u, v): (0.06643187254667282; 0.013614795170724392)
overlap(u, v): (0.017817258834838867; 0.013134319335222244)
overlap(u, v): (0.15326853096485138; 0.03809739276766777)
overlap(u, v): (0.1716625839471817; 0.014546971768140793)
overlap(u, v): (0.19898395240306854; 0.011704301461577415)
overlap(u, v): (0.13106606900691986; 0.019911106675863266)
overlap(u, v): (0.0593767911195755; 0.03627850487828255)
overlap(u, v): (0.08842243999242783; 0.014955433085560799)
overlap(u, v): (0.0013133191969245672; 0.034393806010484695)
overlap(u, v): (0.18540330231189728; 0.0022715111263096333

overlap(u, v): (0.5503881573677063; 0.7177531719207764)
overlap(u, v): (0.6568652391433716; 0.7031650543212891)
overlap(u, v): (0.6347490549087524; 0.6732839345932007)
overlap(u, v): (0.5477347373962402; 0.711808443069458)
overlap(u, v): (0.6600244641304016; 0.7099353671073914)
overlap(u, v): (0.5284516215324402; 0.6739126443862915)
overlap(u, v): (0.6148694157600403; 0.7256671786308289)
overlap(u, v): (0.6354987025260925; 0.7545617818832397)
overlap(u, v): (0.5880118012428284; 0.738242506980896)
overlap(u, v): (0.6730561852455139; 0.7295853495597839)
overlap(u, v): (0.6030410528182983; 0.7650486826896667)
overlap(u, v): (0.6288928389549255; 0.7253687977790833)
overlap(u, v): (0.6910760998725891; 0.7321600317955017)
overlap(u, v): (0.6941534280776978; 0.7441135048866272)
overlap(u, v): (0.5644233822822571; 0.7112429738044739)
overlap(u, v): (0.5607684254646301; 0.7106209993362427)
overlap(u, v): (0.025589004158973694; 0.057163964956998825)
overlap(u, v): (0.03681546077132225; 0.0303260

overlap(u, v): (0.037422653287649155; 0.015331520698964596)
overlap(u, v): (0.059599585831165314; 0.020814334973692894)
overlap(u, v): (0.08301646262407303; 0.003151834476739168)
overlap(u, v): (0.023225631564855576; 0.019207175821065903)
overlap(u, v): (0.07728777080774307; 0.05631203576922417)
overlap(u, v): (0.035733796656131744; 0.0734095424413681)
overlap(u, v): (0.21167892217636108; 0.037800949066877365)
overlap(u, v): (0.03254260495305061; 0.01890428736805916)
overlap(u, v): (0.12494294345378876; 0.006659860257059336)
overlap(u, v): (0.1427566558122635; 0.10593240708112717)
overlap(u, v): (0.12049403786659241; 0.0355578288435936)
overlap(u, v): (0.11407054960727692; 0.03181995451450348)
overlap(u, v): (0.21099989116191864; 0.04520990699529648)
overlap(u, v): (0.07951650768518448; 0.18993692100048065)
overlap(u, v): (0.1500283032655716; 0.029871666803956032)
overlap(u, v): (0.10167527943849564; 0.10534817725419998)
overlap(u, v): (0.3842911422252655; 0.2536233961582184)
overlap(u

overlap(u, v): (0.7955853343009949; 0.6492089629173279)
overlap(u, v): (0.7987297773361206; 0.665155827999115)
overlap(u, v): (0.7366344332695007; 0.6648614406585693)
overlap(u, v): (0.7914207577705383; 0.6841376423835754)
overlap(u, v): (0.8393592834472656; 0.6702287793159485)
overlap(u, v): (0.7539915442466736; 0.668957531452179)
overlap(u, v): (0.7761979699134827; 0.6655551791191101)
overlap(u, v): (0.7679585218429565; 0.6677607893943787)
overlap(u, v): (0.8071796298027039; 0.7009316682815552)
overlap(u, v): (0.793138325214386; 0.7009156942367554)
overlap(u, v): (0.7826755046844482; 0.7160999774932861)
overlap(u, v): (0.8610733151435852; 0.7151532173156738)
overlap(u, v): (0.8275955319404602; 0.7332621812820435)
overlap(u, v): (0.8223005533218384; 0.7190006375312805)
overlap(u, v): (0.7865815758705139; 0.7027533054351807)
overlap(u, v): (0.773414134979248; 0.7039251923561096)
overlap(u, v): (0.7844393849372864; 0.6830583214759827)
overlap(u, v): (0.7968440055847168; 0.73082089424133

overlap(u, v): (0.1866217851638794; 0.03259605914354324)
overlap(u, v): (0.008452597074210644; 0.08134803920984268)
overlap(u, v): (0.16207681596279144; 0.054952215403318405)
overlap(u, v): (0.07807401567697525; 0.05471896380186081)
overlap(u, v): (0.21709293127059937; 0.006668101530522108)
overlap(u, v): (0.010067424736917019; 0.012740730307996273)
overlap(u, v): (0.04765971004962921; 0.10229134559631348)
overlap(u, v): (0.004901409149169922; 0.06845011562108994)
overlap(u, v): (0.10117901116609573; 0.059600237756967545)
overlap(u, v): (0.03996051847934723; 0.009870190173387527)
overlap(u, v): (0.21131214499473572; 0.0073810298927128315)
overlap(u, v): (0.16134290397167206; 0.03435544669628143)
overlap(u, v): (0.00144763826392591; 0.008907816372811794)
overlap(u, v): (0.050475072115659714; 0.03637870401144028)
overlap(u, v): (0.08458032459020615; 0.009208817034959793)
overlap(u, v): (0.10080762952566147; 0.022746030241250992)
overlap(u, v): (0.11204423010349274; 0.08089116960763931)
o

overlap(u, v): (0.675761878490448; 0.4635940492153168)
overlap(u, v): (0.8320350050926208; 0.6387366652488708)
overlap(u, v): (0.8503854870796204; 0.6158663034439087)
overlap(u, v): (0.8585063815116882; 0.6703088283538818)
overlap(u, v): (0.8667343854904175; 0.6302052736282349)
overlap(u, v): (0.8379424810409546; 0.6349532604217529)
overlap(u, v): (0.8518525958061218; 0.6164626479148865)
overlap(u, v): (0.8672508001327515; 0.6393494606018066)
overlap(u, v): (0.8409655690193176; 0.651603639125824)
overlap(u, v): (0.8374025225639343; 0.6320563554763794)
overlap(u, v): (0.8595813512802124; 0.6439099311828613)
overlap(u, v): (0.8835797309875488; 0.675708532333374)
overlap(u, v): (0.8550180196762085; 0.6813857555389404)
overlap(u, v): (0.8846477270126343; 0.6844542026519775)
overlap(u, v): (0.862277090549469; 0.6724746227264404)
overlap(u, v): (0.8526833057403564; 0.6645210981369019)
overlap(u, v): (0.8667455315589905; 0.6766564846038818)
overlap(u, v): (0.865285336971283; 0.692765355110168

In [None]:
# Axis labels for the two scanned parameters and string versions of the sizes.
name_value_u="beta_u"
name_value_v="beta_v"
name_d_u = str(d_u)
name_d_v = str(d_v)

# Output directory for the saved figures: "<home>/Desktop/plots/", resolved
# for the current user instead of a hard-coded absolute path. The trailing
# separator is kept because file names are appended by string concatenation.
FILE_PATH = os.path.join(os.path.expanduser("~"), "Desktop", "plots", "")
os.makedirs(FILE_PATH, exist_ok=True)

In [None]:
#### MSE 3D PLOT BOTH CONFIG
# NOTE(review): the "*_2" names (mse_at_final_iteration_2, d_u_2, d_v_2,
# lambda_u_2, lambda_v_2) are not defined in the visible cells; they must
# come from a second run of the main cell - confirm before running.
x, y = np.meshgrid(list_value_2,list_value_1)

n_value_1 = len(list_value_1)
n_value_2 = len(list_value_2)

# Average the final MSE over every repetition of each (value_1, value_2) pair.
# Flattening the trailing axes first makes this work whether the array is
# (pairs, samples) or carries an extra trailing axis. After the reshape,
# row i corresponds to list_value_1[i] and column j to list_value_2[j].
mse_mean = np.asarray(mse_at_final_iteration).reshape(n_value_1*n_value_2, -1).mean(axis=1)
mse_mean = mse_mean.reshape(n_value_1, n_value_2)

mse_mean_2 = np.asarray(mse_at_final_iteration_2).reshape(n_value_1*n_value_2, -1).mean(axis=1)
mse_mean_2 = mse_mean_2.reshape(n_value_1, n_value_2)

# plotly maps z[i][j] to (x[j], y[i]); the grids are transposed so that
# list_value_1 runs along x and list_value_2 along y, matching the axis titles.
fig = go.Figure(data=[
    go.Surface(name=f"n={d_u_1}, m={d_v_1}; lu={lambda_u_1}, lv={lambda_v_1}",
               z=mse_mean.T,x=list_value_1, y=list_value_2,
               colorscale ='Blues', showscale=False, opacity=1),
    go.Surface(name=f"n={d_u_2}, m={d_v_2}; lu={lambda_u_2}, lv={lambda_v_2}",
               z=mse_mean_2.T,x=list_value_1, y=list_value_2,
               colorscale ='Greens', showscale=False, opacity=0.9)])

fig.update_layout(
    title="MSE:",
    scene = dict(
        xaxis_title=name_value_u,
        yaxis_title=name_value_v,
        zaxis_title="mse"),
)
fig.update_traces(showlegend=True)

fig.show()
path=f"{FILE_PATH}R1_n={d_u_1}_m={d_v_1}_"\
f"Ldiff={lambda_u_1/lambda_v_1}_"\
f"R2_n={d_u_2}_m={d_v_2}_"\
f"Ldiff={lambda_u_2/lambda_v_2}_mse.html"
fig.write_html(path)

In [None]:
#### MSE as a function of the iteration
# NOTE(review): the "*_2" names used in the file name are not defined in the
# visible cells - confirm they exist before running.

# Average over the repeated runs (axis 1). If the history carries an extra
# trailing component axis, average over it as well, so that the result is
# always (number of configurations, iterations).
mse_data = np.asarray(mse_at_each_iterations).mean(axis=1)
if mse_data.ndim > 2:
    mse_data = mse_data.mean(axis=-1)

fig, ax = plt.subplots(figsize=(10,8))
# One smoothed curve per configuration (not per sample).
for config_curve in mse_data:
    ax.plot(savgol_filter(config_curve, 51, 3))
ax.set_ylabel("mse")
ax.set_xlabel("iterations")
path_plot=f"{FILE_PATH}R1_n={d_u_1}_m={d_v_1}_"\
f"Ldiff={lambda_u_1/lambda_v_1}_"\
f"R2_n={d_u_2}_m={d_v_2}_"\
f"Ldiff={lambda_u_2/lambda_v_2}_mse_plot.png"
fig.savefig(path_plot)
plt.show()


In [None]:
# MSE STABILISATION WITH PLOTLY
# CHANGE SELECTED_RUN!
selected_run=0

# History of the selected configuration: (samples, iterations) when the MSE
# is a single value per iteration, or (samples, iterations, 2) when it is
# stored per vector (u, v). Both layouts are handled.
run_mse = np.asarray(mse_at_each_iterations[selected_run])
if run_mse.ndim == 2:
    run_mse = run_mse[:, :, np.newaxis]
    trace_names = ["mse"]
else:
    trace_names = ["u mse", "v mse"]

# Mean and variance over the repeated runs, one row per plotted component.
mse_means = run_mse.mean(axis=0).T
mse_vars = np.var(run_mse, axis=0).T

x = list(range(mse_means.shape[1]))

# NOTE(review): the error bars show the variance, not the standard
# deviation - confirm this is intended.
fig=go.Figure(data=[
    go.Scatter(
        x=x,
        y=component_mean,
        error_y=dict(
            type='data', # value of error bar given in data coordinates
            array=component_var,
            visible=True),
        name=trace_name
    )
    for trace_name, component_mean, component_var in zip(trace_names, mse_means, mse_vars)
])

fig.update_layout(
    title="Stabilisation mse pour: "+f"{name_value_u}={list_value_1[selected_run//len(list_value_2)]}"
                    + f" & {name_value_v}={list_value_2[selected_run%len(list_value_2)]}",
    xaxis_title="Iterations",
    yaxis_title="MSE",
)

fig.show()

In [None]:
# OVERLAP STABILISATION PLOTS
n_configurations = len(list_value_1)*len(list_value_2)

# constrained_layout spaces the subplots at draw time, once titles and labels
# exist (a tight_layout() call made before plotting cannot account for them).
fig, axs = plt.subplots(n_configurations, figsize=(20,200), facecolor='white',
                        constrained_layout=True)
# plt.subplots returns a bare Axes when there is a single subplot.
axs = np.atleast_1d(axs)

# Mean over the repeated runs, computed once instead of at every loop turn.
mean_overlap = overlap_at_each_iterations.mean(1)

for i in range(n_configurations):
    axs[i].grid()
    axs[i].set_ylabel("overlap")
    axs[i].set_title(f"{name_value_u}={list_value_1[i//len(list_value_2)]}"
                    + f" & {name_value_v}={list_value_2[i%len(list_value_2)]}")
    # One line per overlap component (u and v) for this configuration.
    axs[i].plot(mean_overlap[i])