In [None]:
import torch
import matplotlib.pyplot as plt
import non_local_boxes
import numpy as np
from IPython.display import clear_output   # in order to clear the print output
import time

# Sugar coating for reloading
%matplotlib inline
%load_ext autoreload
%autoreload 2

# from IPython.display import set_matplotlib_formats
# set_matplotlib_formats('svg')   # in order to have unblurred pictures
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# Gradient Descent

In [None]:
def _pair_averaging_matrix(first_row, size=32):
    """Return the size x size identity whose 2x2 diagonal block starting at
    (first_row, first_row) is replaced by a block filled with 0.5.

    Left-multiplying a (size, n) tensor by this matrix replaces rows
    ``first_row`` and ``first_row + 1`` by their mean and leaves every other
    row untouched.
    """
    matrix = torch.eye(size)
    matrix[first_row:first_row + 2, first_row:first_row + 2] = 0.5
    return matrix


# Averaging matrices used by projected_wiring. They come in pairs
# (M1, M2), (M3, M4), (M5, M6), (M7, M8): the odd-numbered one averages rows
# (r, r+1) and the even-numbered one averages rows (r+8, r+9).
M1 = _pair_averaging_matrix(0)
M2 = _pair_averaging_matrix(8)

M3 = _pair_averaging_matrix(2)
M4 = _pair_averaging_matrix(10)

M5 = _pair_averaging_matrix(4)
M6 = _pair_averaging_matrix(12)

M7 = _pair_averaging_matrix(6)
M8 = _pair_averaging_matrix(14)

In [None]:
def projected_wiring(W):  # W is a 32xn tensor
    """Clip W to [0, 1] and then, column by column, equalise one row pair per stage.

    There are four stages. Each stage compares two row pairs, (a, a+1) and
    (b, b+1) with b = a + 8. For every column, the pair whose two entries are
    already closer (ties go to the first pair) is replaced by its mean, using
    the matching averaging matrix from M1..M8; all other rows are unchanged.
    Stages run in sequence, each on the output of the previous one.

    Returns a new tensor with the same shape as W.
    """
    # Element-wise clamp into the unit interval.
    floor = torch.zeros_like(W)
    ceiling = torch.ones_like(W)
    W = torch.minimum(torch.maximum(W, floor), ceiling)

    # (first row of pair A, first row of pair B, averager for A, averager for B)
    stages = (
        (0, 8, M1, M2),
        (2, 10, M3, M4),
        (4, 12, M5, M6),
        (6, 14, M7, M8),
    )
    for row_a, row_b, average_a, average_b in stages:
        gap_a = torch.abs(W[row_a, :] - W[row_a + 1, :])
        gap_b = torch.abs(W[row_b, :] - W[row_b + 1, :])
        # Boolean mask over columns: True where pair A is the one to average.
        use_a = gap_a <= gap_b
        use_b = torch.logical_not(use_a)
        averaged_a = torch.tensordot(average_a, W, dims=1)
        averaged_b = torch.tensordot(average_b, W, dims=1)
        W = use_a * averaged_a + use_b * averaged_b

    return W

In [None]:
def gradient_descent(starting_W, P, Q, learning_rate, nb_iterations = 400, tolerance=1e-6):
    """Projected gradient iteration on ``non_local_boxes.evaluate.phi_flat``.

    Despite the name, each step moves along ``+learning_rate * gradient`` and
    therefore increases phi_flat; the result is then passed through
    ``projected_wiring``.

    Parameters
    ----------
    starting_W : torch.Tensor
        Initial wiring. It is copied, so neither its values nor its ``.grad``
        are modified, and it does not need ``requires_grad=True``.
        Assumed to have shape (32, nb_columns) -- TODO confirm against phi_flat.
    P, Q :
        Forwarded unchanged to ``phi_flat``.
    learning_rate : float
        Step size applied to the gradient.
    nb_iterations : int
        Maximum number of steps.
    tolerance : float
        Stop as soon as the largest absolute entry change of one step is
        below this value.

    Returns
    -------
    torch.Tensor
        The last iterate. It is detached from the autograd graph; its
        ``requires_grad`` flag is False on early stop and True otherwise.
    """
    m = non_local_boxes.evaluate.nb_columns
    # backward() on a non-scalar output needs an explicit upstream gradient;
    # all-ones makes every column contribute its own gradient.
    external_grad = torch.ones(m)
    # Private leaf copy: the first backward() must not accumulate into a
    # gradient already stored on the caller's tensor, and must also work when
    # the caller's tensor does not require grad.
    W = starting_W.detach().clone().requires_grad_(True)
    for _ in range(nb_iterations):
        Wold = W
        non_local_boxes.evaluate.phi_flat(W, P, Q).backward(gradient=external_grad)
        # detach() produces a fresh leaf, so W.grad starts empty at every step.
        W = projected_wiring(W + learning_rate*W.grad).detach()
        if torch.max(torch.abs(W - Wold)) < tolerance:
            return W
        W.requires_grad = True
    return W

### Histogram

In [None]:
# Short aliases for the three reference boxes exported by non_local_boxes.utils.
PR, SR, I = (
    non_local_boxes.utils.PR,
    non_local_boxes.utils.SR,
    non_local_boxes.utils.I,
)

In [None]:
# Box under study: a mixture of PR, SR and I with weights p, q and 1-p-q.
p=0.39
q=0.6
P = p*PR +q*SR + (1-p-q)*I
BoxProduct = non_local_boxes.evaluate.phi_flat

# Gradient-descent settings: step size, iteration cap and stopping tolerance.
m = non_local_boxes.evaluate.nb_columns
alpha = 0.01
K=int(1e4)
epsilon=1e-6

W = gradient_descent(
    starting_W=non_local_boxes.utils.random_wiring(m),
    P=P,
    Q=P,
    learning_rate=alpha,
    nb_iterations=K,
    tolerance=epsilon
)
# One value per column of the final wiring.
histogramGD = BoxProduct(W, P, P).tolist()

#plt.hist(histogramGD, bins=50, label="Gradient Descent (p="+str(p)+", q="+str(q)+", α="+str(alpha)+", K="+str(K)+", m=10^"+str(int(np.log10(m)))+", ε=10^"+str(int(np.log10(epsilon)))+")")
plt.hist(
    histogramGD,
    bins=50,
    label=f"Gradient Descent (α={alpha}, K=10^{int(np.log10(K))}, ε=10^{int(np.log10(epsilon))}, m=10^{int(np.log10(m))})",
)
#plt.xlabel("CHSH-value")
# Raw string: "\P" and "\m" are not valid Python escape sequences, so a plain
# string literal triggers a SyntaxWarning on recent Python versions.
plt.xlabel(r"$\Phi(\mathsf{W}_{{out}})$")
plt.ylabel("Number of reruns")
plt.yscale("log")
plt.legend()
#plt.title("Histogram of the different results with a random initialization (with $\mathbf{P}=(\mathbf{PR}+\mathbf{SR})/2$, total: "+str(N)+" occurences)")
plt.show()

In [None]:
# Print every entry of histogramGD on its own line, followed by a comma.
for value in histogramGD:
    print(value, end=",\n")

# Line Search

$$
\left\{
\begin{array}{l}
    \alpha^*_k = \operatorname*{argmax}_\alpha \phi(x_k + \alpha \nabla \phi(x_k))\\
    x_{k+1}= \texttt{proj}(x_k + \alpha^*_k \nabla \phi(x_k))
\end{array}
\right.
$$

In [None]:
def reorder_list(L, phi):
    """Sort the index list ``L`` in place by decreasing ``phi`` value and return it.

    Parameters
    ----------
    L : list of int
        Indices into ``phi``. The list is reordered in place.
    phi : sequence
        Values used as sort keys; ``phi[i]`` is the key of index ``i``.

    Returns
    -------
    list of int
        The same list object ``L``, ordered so that ``phi[L[0]] >= phi[L[1]] >= ...``.
        Indices with equal values keep their relative order.
    """
    # list.sort is stable even with reverse=True, so ties stay in their
    # original order, exactly as with an adjacent-swap sort that only swaps
    # on a strict "<".
    L.sort(key=lambda index: phi[index], reverse=True)
    return L

In [None]:
# Small check of reorder_list: the printed values should come out in
# decreasing order.
phi = [0.1, 0.3, 0, 10, 9, 0.5]
L = reorder_list(list(range(len(phi))), phi)
print([phi[index] for index in L])

In [None]:
def select_best_columns(W, P, Q, integer):
    """Keep the ``integer`` best columns of W and redraw all the others at random.

    Columns are ranked by ``non_local_boxes.evaluate.phi_flat(W, P, Q)``
    (larger is better; on ties the column with the smaller index wins).

    Parameters
    ----------
    W : torch.Tensor
        Current wirings, one per column.
        Assumed to have ``non_local_boxes.evaluate.nb_columns`` columns -- TODO confirm.
    P, Q :
        Forwarded unchanged to ``phi_flat``.
    integer : int
        Number of columns to keep. With 0, W is ignored and a completely
        fresh random wiring is returned.

    Returns
    -------
    torch.Tensor
        A new tensor from ``non_local_boxes.utils.random_wiring`` in which the
        selected columns have been overwritten by the corresponding columns of W.
    """
    # Use the library's column count rather than a notebook-level global, so
    # the function does not depend on which cells were executed before it.
    nb_columns = non_local_boxes.evaluate.nb_columns
    if integer == 0:
        return non_local_boxes.utils.random_wiring(nb_columns).detach()

    # One score per column.
    phi = non_local_boxes.evaluate.phi_flat(W, P, Q).tolist()

    # Stable descending sort: among equal scores the smaller index comes first.
    # This selects the same set as a running top-k list that only admits
    # strictly better candidates.
    best = sorted(range(nb_columns), key=phi.__getitem__, reverse=True)[:integer]

    W_new = non_local_boxes.utils.random_wiring(nb_columns).detach()
    W_new[:, best] = W[:, best]  # we keep only the best ones

    return W_new

In [None]:
def line_search_with_resets(P, Q, LS_iterations, K_reset, chi):
    """Projected gradient ascent with a per-column line search and periodic resets.

    The run is split into ``int(1/chi)`` rounds. Round ``j`` starts by keeping
    the ``min(m, int(j*m*chi))`` best columns of the current wiring and
    redrawing the others (``select_best_columns``); round 0 therefore starts
    from a completely random wiring. It then performs ``K_reset`` line-search
    steps (``10*K_reset`` in the last round).

    One line-search step:
      * compute the gradient of ``phi_flat`` at W;
      * start from a step size of 0.01 for every column and adjust it
        ``LS_iterations`` times: halve it for columns where the trial point
        ``W + alpha*gradient`` is worse than W, multiply it by 1.7 otherwise;
      * move to ``projected_wiring(W + alpha*gradient)``.

    Parameters
    ----------
    P, Q :
        Forwarded unchanged to ``phi_flat``.
    LS_iterations : int
        Number of step-size adjustments per step.
    K_reset : int
        Number of steps between two resets.
    chi : float
        Fraction that sets both the number of rounds and how many additional
        columns survive each reset.

    Returns
    -------
    torch.Tensor
        Final wiring with ``requires_grad=True``.
    """
    m = non_local_boxes.evaluate.nb_columns
    phi_flat = non_local_boxes.evaluate.phi_flat
    W = torch.zeros(32,m)
    external_grad = torch.ones(m)
    nb_rounds = int(1/chi)

    for j in range(nb_rounds):
        # Reset:
        W = select_best_columns(W, P, Q, min(m, int(j*m*chi))).detach()
        W.requires_grad=True

        # At the end, we do a lot of steps:
        nb_steps = 10*K_reset if j == nb_rounds-1 else K_reset

        # Line search:
        for _step in range(nb_steps):
            # One forward pass gives both the gradient and the reference
            # values; W does not change during the trials below.
            gains = phi_flat(W, P, Q)
            gains.backward(gradient=external_grad)
            gradient = W.grad
            gains = gains.detach()
            alpha = torch.ones(m)*0.01
            # Trial evaluations are only compared, never differentiated.
            with torch.no_grad():
                for _trial in range(LS_iterations):
                    gains_new = phi_flat(W + alpha*gradient, P, Q)
                    mask = 0.0 + (gains > gains_new)  # 1.0 where the trial is worse
                    alpha = 0.5*mask*alpha + 1.7*(1-mask)*alpha
                W = projected_wiring(W + alpha*gradient).detach()
            W.requires_grad=True

    return W

In [None]:
# Box under study: a mixture of PR, SR and I with weights p, q and 1-p-q.
p=0.39
q=0.6
P = p*PR +q*SR + (1-p-q)*I
BoxProduct = non_local_boxes.evaluate.phi_flat

# Line-search settings: steps between resets, reset fraction, and number of
# step-size adjustments per step.
m = non_local_boxes.evaluate.nb_columns
K_reset=100
chi = 0.0003
LS_iterations = 7

W=line_search_with_resets(
    P, 
    P, 
    LS_iterations=LS_iterations, 
    K_reset=K_reset, 
    chi=chi
    )
# One value per column of the final wiring.
histogramLS = BoxProduct(W, P, P).tolist()

#plt.hist(histogramGD, bins=50, label="Gradient Descent (p="+str(p)+", q="+str(q)+", α="+str(alpha)+", K="+str(K)+", m=10^"+str(int(np.log10(m)))+", ε=10^"+str(int(np.log10(epsilon)))+")")
plt.hist(
    histogramLS,
    bins=50,
    color='purple',
    label=f"Line Search ($K_{{reset}}$={K_reset}, χ={chi}, m=10^{int(np.log10(m))}, M={LS_iterations})",
)
#plt.xlabel("CHSH-value")
# Raw string: "\P" and "\m" are not valid Python escape sequences, so a plain
# string literal triggers a SyntaxWarning on recent Python versions.
plt.xlabel(r"$\Phi(\mathsf{W}_{{out}})$")
plt.ylabel("Number of reruns")
plt.yscale("log")
plt.legend()
#plt.title("Histogram of the different results with a random initialization (with $\mathbf{P}=(\mathbf{PR}+\mathbf{SR})/2$, total: "+str(N)+" occurences)")
plt.show()

In [None]:
# Print every entry of histogramLS on its own line, followed by a comma.
for value in histogramLS:
    print(value, end=",\n")

-----
## Saving

In [None]:
def line_search_with_resets(P, Q, LS_iterations, K_reset, chi, K=None, epsilon=1e-6):
    """Variant of the line search with resets where the reset comes AFTER each round.

    NOTE: once this cell runs, this definition replaces the earlier function
    of the same name. ``K`` therefore has a default, so that calls written for
    the earlier five-argument signature keep working.

    The run starts from a random wiring and performs ``int(1/chi)`` rounds,
    numbered j = 1, 2, ... Each round does ``K_reset`` line-search steps and
    then keeps the ``min(m, int(j*m*chi))`` best columns, redrawing the
    others (``select_best_columns``).

    One line-search step:
      * compute the gradient of ``phi_flat`` at W;
      * start from a step size of 0.01 for every column and adjust it
        ``LS_iterations`` times: halve it for columns where the trial point
        ``W + alpha*Wgrad`` is worse than W, multiply it by 1.7 otherwise;
      * move to ``projected_wiring(W + alpha*Wgrad)``.

    Parameters
    ----------
    P, Q :
        Forwarded unchanged to ``phi_flat``.
    LS_iterations : int
        Number of step-size adjustments per step.
    K_reset : int
        Number of steps per round.
    chi : float
        Fraction that sets both the number of rounds and how many columns
        survive each reset.
    K, epsilon :
        Currently unused. They are only referenced by the disabled final
        gradient-descent pass kept as a comment at the end of the body.

    Returns
    -------
    torch.Tensor
        Final wiring with ``requires_grad=True``.
    """
    m = non_local_boxes.evaluate.nb_columns
    phi_flat = non_local_boxes.evaluate.phi_flat
    W = non_local_boxes.utils.random_wiring(m)
    external_grad = torch.ones(m)

    for j in range(1,int(1/chi)+1):
        for _step in range(K_reset):
            # line search
            # One forward pass gives both the gradient and the reference
            # values; W does not change during the trials below.
            gains = phi_flat(W, P, Q)
            gains.backward(gradient=external_grad)
            Wgrad = W.grad
            gains = gains.detach()
            alpha = torch.ones(m)*0.01
            # Trial evaluations are only compared, never differentiated.
            with torch.no_grad():
                for _trial in range(LS_iterations):
                    gains_new = phi_flat(W + alpha*Wgrad, P, Q)
                    mask = 0.0 + (gains > gains_new)  # 1.0 where the trial is worse
                    alpha = 0.5*mask*alpha + 1.7*(1-mask)*alpha
                W = projected_wiring(W + alpha*Wgrad).detach()
            W.requires_grad=True
        # reset
        W = select_best_columns(W, P, Q, min(m, int(j*m*chi))).detach()
        W.requires_grad=True

    #W = gradient_descent(W, P, Q, alpha*0.1, nb_iterations = K, tolerance=epsilon)

    return W