In [None]:
#Experiments with diagonalmatrix d_i=1/i. I'm printing out all of the eigenvectors I could get.

In [8]:
import jax.numpy as jnp
import jax


jax.config.update("jax_enable_x64", True)

def complex_cbrt(z):
    # Calculate the magnitude and angle of the complex number
    r = jnp.abs(z)  # Magnitude
    theta = jnp.angle(z)  # Angle in radians

    # Calculate the cubic root of the magnitude and the angle divided by 3
    root_magnitude = r ** (1/3)  # Magnitude of the cubic root
    root_angle = theta / 3  # Angle of the cubic root

    # Convert back to rectangular form (Cartesian coordinates)
    root_real = root_magnitude * jnp.cos(root_angle)
    root_imag = root_magnitude * jnp.sin(root_angle)

    return root_real + 1j * root_imag  # Return as a complex number

def optimal_lr(A,x):
  s=jnp.matmul(A,x)
  t=jnp.matmul(A,s)
  u=jnp.matmul(A,t)
  a_0=jnp.inner(x,x)
  a_1=jnp.inner(s,x)
  a_2=jnp.inner(t,x)
  a_3=jnp.inner(u,x)
  a_4=jnp.inner(u,s)
  a_5=jnp.inner(u,t)
  a_6=jnp.inner(u,u)
  r_0=4*a_2/(a_1*a_1)-2*a_1/(a_1*a_0)-2*a_3/(a_1*a_2)-2*a_3/(a_1*a_2)+a_4/(a_2*a_2)+a_2/(a_0*a_2)-2*a_1/(a_1*a_0)+a_0/(a_0*a_0)+a_2/(a_0*a_2)
  r_1=4*a_3/(a_1*a_1)-2*a_2/(a_1*a_0)-2*a_4/(a_1*a_2)-2*a_4/(a_1*a_2)+a_5/(a_2*a_2)+a_3/(a_0*a_2)-2*a_2/(a_1*a_0)+a_1/(a_0*a_0)+a_3/(a_0*a_2)
  r_2=4*a_4/(a_1*a_1)-2*a_3/(a_1*a_0)-2*a_5/(a_1*a_2)-2*a_5/(a_1*a_2)+a_6/(a_2*a_2)+a_4/(a_0*a_2)-2*a_3/(a_1*a_0)+a_2/(a_0*a_0)+a_4/(a_0*a_2)
  q_1=2*a_2/a_1-a_1/a_0-a_3/a_2
  q_2=2*a_3/a_1-a_2/a_0-a_4/a_2
  p_0=a_0
  p_1=a_1
  p_2=a_2
  a = r_0 *r_1 *q_2 - 2*r_0 *q_1 *r_2
  b = p_0*r_1*r_2 - 2*p_1*r_0*r_2 + p_2*r_0*r_1 - 2*q_1*q_2*r_0
  c = 3*p_0 *r_1 *q_2 - 3*r_0 *p_1 *q_2
  d = 2*p_0 *q_1 *q_2 + 2*p_0 *r_1 *p_2 - p_0 *p_1 *r_2 - r_0 *p_1 *p_2
  e = 2*p_0 *q_1 *p_2 - p_0 *p_1 *q_2
  lr=jnp.real(jnp.roots(jnp.array([a,b,c,d,e])))
  
  jnp.pad(lr, (0, 4-lr.size), mode='constant', constant_values=0)
  
  b=b/a
  c=c/a
  d=d/a
  e=e/a
  a=1
  Delta_0 = c**2 - 3*b*d + 12*a*e
  Delta_1 = 2*c**3 - 9*b*c*d + 27*b**2*e + 27*a*d**2 - 72*a*c*e
  Q=complex_cbrt((Delta_1+jnp.sqrt(jax.lax.complex(Delta_1**2-4*Delta_0**3, 0.)))/2)
  p=(8*a*c-3*b**2)/(8*a**2)
  q=(b**3-4*a*b*c+8*a**2*d)/(8*a**3)
  #q=((b/(2*a))**3-b/(2*a)*c/a+d/a)
  S=jnp.sqrt(-2/3*p+(Q+Delta_0/Q)/(3*a))/2
  S_1=jnp.sqrt(-4*S**2-2*p+q/S)
  S_2=jnp.sqrt(-4*S**2-2*p-q/S)
  lr_1=jnp.real(-b/(4*a)-S+S_1/2)
  lr_2=jnp.real(-b/(4*a)-S-S_1/2)
  lr_3=jnp.real(-b/(4*a)+S+S_2/2)
  lr_4=jnp.real(-b/(4*a)+S-S_2/2)
  lr_1, lr_2, lr_3, lr_4=lr[0],lr[1],lr[2],lr[3]
  #l1,l2,l3,l4=jnp.roots(jnp.array([a,b,c,d,e]))
  g_1=a_1+2*lr_1*q_1+lr_1**2*r_1
  g_2=a_1+2*lr_2*q_1+lr_2**2*r_1
  g_3=a_1+2*lr_3*q_1+lr_3**2*r_1
  g_4=a_1+2*lr_4*q_1+lr_4**2*r_1
  f_1=a_0+lr_1**2*r_0
  f_2=a_0+lr_2**2*r_0
  f_3=a_0+lr_3**2*r_0
  f_4=a_0+lr_4**2*r_0
  h_1=a_2+2*lr_1*q_2+lr_1**2*r_2
  h_2=a_2+2*lr_2*q_2+lr_2**2*r_2
  h_3=a_2+2*lr_3*q_2+lr_3**2*r_2
  h_4=a_2+2*lr_4*q_2+lr_4**2*r_2
  eigenness_1=(g_1**2)/(f_1*h_1)
  eigenness_2=(g_2**2)/(f_2*h_2)
  eigenness_3=(g_3**2)/(f_3*h_3)
  eigenness_4=(g_4**2)/(f_4*h_4)
  if eigenness_1>=eigenness_2 and eigenness_1>=eigenness_3 and eigenness_1>=eigenness_4:
    return eigenness_1, (-lr_1/a_0)*x+(2*lr_1/a_1)*s+(-lr_1/a_2)*t, 1, lr_1
  elif eigenness_2>=eigenness_1 and eigenness_2>=eigenness_3 and eigenness_2>=eigenness_4:
    return eigenness_2, (-lr_2/a_0)*x+(2*lr_2/a_1)*s+(-lr_2/a_2)*t, 2, lr_2
  elif eigenness_3>=eigenness_1 and eigenness_3>=eigenness_2 and eigenness_3>=eigenness_4:
    return eigenness_3, (-lr_3/a_0)*x+(2*lr_3/a_1)*s+(-lr_3/a_2)*t, 3, lr_3
  elif eigenness_4>=eigenness_1 and eigenness_4>=eigenness_2 and eigenness_4>=eigenness_3:
    return eigenness_4, (-lr_4/a_0)*x+(2*lr_4/a_1)*s+(-lr_4/a_2)*t, 4, lr_4


def grad_ascend_lr(A,x,steps_already):
    x=x/jnp.linalg.norm(x)
    a,v,i,lr=optimal_lr(A,x)
    step=steps_already
    x=x+v
    x=x/jnp.linalg.norm(x)
    while a<0.9999 and step<2000:
      a,v,i,lr=optimal_lr(A,x)
      step+=1 
      x=x+v
      x=x/jnp.linalg.norm(x)
      #print("           ", np.ndarray.tolist(np.argsort(np.abs(jnp.array(x)))), np.ndarray.tolist(np.argsort(np.diag(jnp.array(A)))), end="\r")
        
    return x/jnp.sqrt(jnp.inner(x,x)), jnp.inner(jnp.matmul(A,x), x)/jnp.inner(x,x), jnp.linalg.norm(jnp.matmul(A,x)-jnp.inner(jnp.matmul(A,x), x)/jnp.inner(x,x)*x), step

import jax.random as jra
import numpy as np
import sys

size=200

count=np.zeros(size)
num_steps=[]
coords=[]
breaks=[]

num_experiments=100000

def grad_ascend_lr_transform(C,x,j,steps_already):
    f=jra.uniform(key = jra.PRNGKey(3*j))
    B=C-f*I
    B=B@B
    #print(jnp.argsort(jnp.diag(B)))
    s=(10*jra.cauchy(key = jra.PRNGKey(3*j+2))*f)**2
    B=(2*s*B+B@B)/((s+max(f**2,(1-f)**2))**2-s**2)
    #print(jnp.argsort(jnp.diag(B)))
    t=jra.uniform(key = jra.PRNGKey(3*j+2))
    #print(jnp.argsort(jnp.diag(B)))
    B=t*I-B
    B=(t**2)*I-B@B
    #print(jnp.argsort(jnp.diag(B)))
    #print("")
    
    return grad_ascend_lr(B,x, steps_already)


A=jnp.diag(1/(1+jnp.array(range(size))))
I=jnp.diag(jnp.ones((size,)))
for i in range(num_experiments):
    j=0
    x=jra.normal(key = jra.PRNGKey(i), shape=(size,))
    t=grad_ascend_lr_transform(A,x,i,0)
    steps=t[-1]
    while jnp.max(jnp.abs(t[0]))<0.9 and steps<2000:
        j+=1
        t=grad_ascend_lr_transform(A,t[0],j*i,steps)
        steps=t[-1]
        
        

    num_steps.append(steps)
    count[jnp.where(jnp.abs(t[0])>0.9)]+=1
    coords.append([i,j,jnp.where(jnp.abs(t[0])>0.9)])
    print(f"     {i}"+f" {max(num_steps)}"+f"{list(np.where(count>0)[0])}"+f"{np.sum(count)}", end="\r")
    
    
print(count)
print(np.sum(count))
print(max(num_steps))
print(breaks)

     21320 2000[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 49, 50, 51, 52, 53, 54, 57, 65, 66, 72, 126, 198, 199]12292.0

KeyboardInterrupt: 

In [9]:
print(count)

[8.355e+03 1.460e+03 8.430e+02 3.080e+02 1.520e+02 1.570e+02 1.670e+02
 1.400e+02 1.130e+02 1.020e+02 7.700e+01 4.900e+01 3.600e+01 1.900e+01
 1.900e+01 1.400e+01 2.300e+01 1.400e+01 8.000e+00 1.200e+01 6.000e+00
 1.300e+01 7.000e+00 1.200e+01 1.200e+01 1.400e+01 1.600e+01 1.200e+01
 1.400e+01 1.400e+01 9.000e+00 1.200e+01 7.000e+00 7.000e+00 7.000e+00
 8.000e+00 5.000e+00 2.000e+00 6.000e+00 4.000e+00 4.000e+00 3.000e+00
 2.000e+00 2.000e+00 4.000e+00 1.000e+00 1.000e+00 0.000e+00 0.000e+00
 1.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 2.000e+00 0.000e+00
 0.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 1.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000

In [10]:
print(coords)

[[0, 0, (Array([0], dtype=int64),)], [1, 0, (Array([], shape=(0,), dtype=int64),)], [2, 2, (Array([10], dtype=int64),)], [3, 0, (Array([0], dtype=int64),)], [4, 0, (Array([7], dtype=int64),)], [5, 10, (Array([0], dtype=int64),)], [6, 2, (Array([0], dtype=int64),)], [7, 0, (Array([0], dtype=int64),)], [8, 0, (Array([0], dtype=int64),)], [9, 0, (Array([], shape=(0,), dtype=int64),)], [10, 2, (Array([0], dtype=int64),)], [11, 6, (Array([], shape=(0,), dtype=int64),)], [12, 0, (Array([0], dtype=int64),)], [13, 0, (Array([0], dtype=int64),)], [14, 0, (Array([2], dtype=int64),)], [15, 8, (Array([], shape=(0,), dtype=int64),)], [16, 4, (Array([], shape=(0,), dtype=int64),)], [17, 7, (Array([], shape=(0,), dtype=int64),)], [18, 0, (Array([0], dtype=int64),)], [19, 0, (Array([0], dtype=int64),)], [20, 0, (Array([0], dtype=int64),)], [21, 2, (Array([0], dtype=int64),)], [22, 2, (Array([0], dtype=int64),)], [23, 0, (Array([1], dtype=int64),)], [24, 0, (Array([0], dtype=int64),)], [25, 2, (Array([

In [13]:
for e in coords:
    if e[2][0].size>0:
        if e[2][0]>20:
            print(e[0], e[1])
        

94 3
153 2
245 3
282 3
293 2
339 2
507 0
523 2
754 0
782 2
913 4
936 2
1046 6
1204 3
1215 2
1249 0
1254 0
1268 2
1332 0
1399 6
1411 0
1497 0
1566 0
1633 0
1665 0
1696 0
1931 0
1999 2
2086 0
2209 0
2328 2
2336 0
2337 0
2449 2
2491 0
2511 0
2536 0
2549 0
2609 3
2856 3
2975 0
3058 0
3266 0
3284 0
3331 0
3371 0
3472 0
3569 0
3612 0
3717 0
3748 0
3815 0
4013 0
4053 13
4219 0
4289 3
4323 0
4455 0
4472 0
4592 0
4647 11
4652 0
4654 2
4926 10
5389 0
5483 0
5514 10
5589 0
5635 0
5676 4
5727 0
5775 0
5888 2
5957 2
6366 0
6376 0
6438 0
6443 0
6757 5
6856 2
6950 0
7157 0
7222 2
7269 0
7270 0
7300 0
7371 0
7558 0
7593 0
7760 2
7791 8
7972 0
8064 0
8084 0
8105 0
8227 0
8365 0
8549 0
8593 3
8597 2
8608 0
8650 0
8900 0
8916 8
8923 0
9033 0
9052 2
9059 0
9141 0
9184 0
9275 0
9308 0
9366 2
9404 0
9421 4
9444 3
9665 0
10009 0
10091 2
10179 2
10290 0
10374 5
10390 2
10430 0
10513 2
10683 3
11253 0
11380 0
11516 0
11573 0
11665 0
11861 0
11996 0
12318 0
12369 0
12416 0
12430 0
12463 0
12471 3
12755 0
12759 

In [16]:
for e in coords:
    if e[1]>3 and e[2][0].size>0:
        print(e[0],e[1],e[2][0])

5 10 [0]
45 11 [1]
90 8 [0]
107 5 [1]
133 4 [3]
150 4 [2]
160 6 [3]
180 9 [2]
183 5 [0]
196 9 [1]
228 4 [0]
237 4 [7]
241 10 [0]
270 5 [0]
277 22 [9]
284 8 [2]
287 7 [2]
291 4 [0]
310 4 [0]
315 6 [1]
325 4 [1]
328 11 [2]
330 6 [0]
335 5 [1]
347 7 [1]
355 8 [0]
362 10 [1]
364 7 [0]
366 6 [0]
369 5 [2]
386 4 [0]
392 7 [2]
414 4 [0]
445 7 [0]
452 4 [1]
455 4 [8]
486 8 [0]
490 5 [0]
503 4 [1]
508 4 [2]
512 5 [16]
539 4 [1]
540 8 [0]
560 9 [1]
570 4 [1]
581 4 [1]
597 5 [0]
602 4 [8]
606 4 [2]
614 7 [10]
624 4 [0]
628 8 [2]
631 6 [0]
642 9 [3]
644 4 [1]
656 4 [0]
686 6 [0]
693 6 [3]
694 6 [1]
711 10 [0]
714 4 [5]
725 10 [0]
727 5 [1]
743 4 [0]
756 7 [2]
757 10 [2]
759 9 [0]
780 5 [0]
787 4 [0]
790 5 [9]
796 6 [2]
802 8 [1]
812 4 [2]
825 4 [2]
827 5 [0]
841 11 [0]
842 6 [0]
843 7 [0]
847 8 [1]
867 6 [1]
890 8 [1]
913 4 [25]
920 5 [0]
923 8 [3]
941 20 [2]
955 4 [0]
964 4 [1]
972 4 [0]
973 20 [0]
980 5 [0]
984 4 [0]
989 5 [4]
993 6 [2]
997 4 [0]
1017 10 [6]
1023 10 [2]
1027 6 [0]
1038 7 [6]
104

In [17]:
s=0
for e in coords:
    if e[2][0].size>0:
        s=max(s,e[1])
print(s)

40


In [18]:
fyh=[]
for e in coords:
    fyh.append(e[1])
print(fyh)

[0, 0, 2, 0, 0, 10, 2, 0, 0, 0, 2, 6, 0, 0, 0, 8, 4, 7, 0, 0, 0, 2, 2, 0, 0, 2, 0, 2, 0, 2, 4, 2, 2, 0, 0, 0, 4, 0, 3, 0, 3, 2, 0, 2, 4, 11, 0, 2, 2, 0, 0, 3, 2, 3, 0, 5, 0, 2, 0, 2, 2, 0, 0, 2, 0, 3, 0, 2, 2, 2, 6, 0, 2, 0, 0, 3, 0, 0, 3, 0, 5, 0, 0, 0, 2, 2, 0, 0, 2, 2, 8, 2, 0, 2, 3, 0, 0, 2, 3, 0, 4, 0, 0, 2, 0, 4, 3, 5, 2, 0, 2, 0, 3, 0, 0, 0, 3, 0, 0, 2, 0, 0, 0, 4, 0, 0, 2, 0, 3, 0, 3, 4, 2, 4, 0, 7, 0, 3, 0, 2, 0, 2, 3, 4, 0, 4, 6, 2, 9, 0, 4, 3, 2, 2, 0, 3, 3, 0, 0, 0, 6, 2, 4, 0, 0, 3, 0, 2, 0, 2, 0, 0, 2, 2, 2, 0, 0, 0, 2, 5, 9, 0, 0, 5, 6, 2, 0, 0, 2, 2, 0, 3, 2, 0, 0, 0, 9, 0, 2, 2, 0, 0, 0, 6, 0, 0, 0, 8, 0, 4, 2, 2, 8, 0, 2, 2, 0, 2, 0, 0, 2, 3, 6, 2, 2, 2, 3, 0, 4, 2, 0, 2, 0, 2, 0, 3, 2, 4, 2, 2, 2, 10, 3, 5, 0, 3, 2, 3, 0, 0, 0, 2, 0, 0, 2, 2, 11, 0, 3, 0, 2, 2, 2, 4, 0, 0, 5, 0, 0, 3, 5, 2, 2, 2, 2, 0, 4, 22, 0, 0, 13, 0, 3, 3, 8, 0, 2, 7, 2, 0, 0, 4, 3, 2, 0, 2, 5, 0, 3, 14, 2, 6, 8, 0, 3, 2, 0, 0, 2, 2, 4, 0, 3, 8, 2, 6, 2, 2, 0, 6, 0, 2, 0, 3, 0, 4, 0, 4, 11, 2, 6