In [23]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians
    #print("z")
    #print(z)
    #print("r")
    #print(r)
    #print("theta")
    #print(theta)

    # Calculate the cubic root of the magnitude and the angle divided by 3
    r = r ** (1/3)  # Magnitude of the cubic root
    theta = theta / 3  # Angle of the cubic root
    #print("root_magnitude")
    #print(r)
    #print("root_angle")
    #print(theta)

    return r * torch.cos(theta) + 1j * r * torch.sin(theta)  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e = e/a
    d = d/a
    c = c/a
    b = b/a
    i, lr_1 = torch.frexp(e)
    i, lr_2 = torch.frexp(d)
    i, lr_3 = torch.frexp(c)
    i, lr_4 = torch.frexp(b)
    i = torch.max(
        torch.cat((2 + 3 * torch.max(torch.cat((2 * lr_3, 2 + lr_4 + lr_2, 4 + lr_1), 
                                           dim = -1), dim = -1, keepdim = True)[0], 
                  2 * torch.max(torch.cat((1 + 3 * lr_3, 3 + lr_4 + lr_3 + lr_2, 5 + 2 * lr_4 + lr_1, 5 + 2 * lr_2, 8 + lr_3 + lr_1), 
                                        dim = -1), dim = -1, keepdim = True)[0]), 
                  dim = -1), dim = -1, keepdim = True)[0]//12
        
    a = c ** 2 - 3 * b * d + 12 * e
    e = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e
    e = torch.ldexp(e, -6*i)
    a = torch.ldexp(a, -4*i)
    #print("Delta_0")
    #print(a)
    #print("Delta_1")
    #print(e)
    #print("under sqrt")
    #print(e ** 2 - 4 * a ** 3)
    #print("sqrt")
    #print(torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64)))

    Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2) 
    #print("Q")
    #print(Q)
    
    # Calculate p and q
    e = torch.ldexp(c - 0.375 * b ** 2, -2*i)
    d = torch.ldexp((0.5*b) ** 3 - 0.5 * b * c + d, -3*i)
    #print("q")
    #print(d)
    S = torch.sqrt(-2 / 3 * e + (Q + a / Q) / 3) / 2
    a = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e + d / S)
    c = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e - d / S)
    b = torch.ldexp(-b, -i-2)

    lr_1 = torch.ldexp(torch.real(b - S + a), i)
    lr_2 = torch.ldexp(torch.real(b - S - a), i)
    lr_3 = torch.ldexp(torch.real(b + S + c), i)
    lr_4 = torch.ldexp(torch.real(b + S - c), i)
    
    return torch.cat([lr_1, lr_2, lr_3, lr_4],dim = -1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    a = c - b ** 2 / 3
    d = 2/27*b**3 - b*c/3 + d
    c = (a/3)**3 + (d/2)**2
    c = torch.sqrt(c)
    alpha = -d/2 + c
    beta = -d/2 - c
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta) + 1j * sqrt_3 * (alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta) - 1j * sqrt_3 * (alpha-beta))/2)
    return torch.cat([lr1, lr2, lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1, x2, torch.zeros_like(x1), torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.sum(x*x, dim = -1, keepdim = True)
    a_1 = torch.sum(u*x, dim = -1, keepdim = True)
    a_2 = torch.sum(v*x, dim = -1, keepdim = True)
    a_3 = torch.sum(w*x, dim = -1, keepdim = True)
    a_4 = torch.sum(w*u, dim = -1, keepdim = True)
    a_5 = torch.sum(w*v, dim = -1, keepdim = True)
    a_6 = torch.sum(w*w, dim = -1, keepdim = True)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    #p_0 = a_0
    #p_1 = a_1
    #p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = a_0 * r_1 * r_2 - 2 * a_1 * r_0 * r_2 + a_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * a_0 * r_1 * q_2 - 3 * r_0 * a_1 * q_2
    d = 2 * a_0 * q_1 * q_2 + 2 * a_0 * r_1 * a_2 - a_0 * a_1 * r_2 - r_0 * a_1 * a_2
    e = 2 * a_0 * q_1 * a_2 - a_0 * a_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), lr * (-x / a_0 + 2 * u / a_1 - v / a_2), n, lr
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x = x + v
    return x/torch.linalg.norm(x, dim = -1, keepdim = True)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A, x)
    x = update_vector(x, v)
    steps = steps_already + 1
    cond = (1 - e > threshold) & (steps < steps_max)
    while cond.any():
        steps = torch.where(cond, steps + 1, steps)
        e,v,i,l = optimal_lr(A, x)
        x = torch.where(cond, update_vector(x, v), x)
        cond = (1 - e > threshold) & (steps < steps_max)
    v=[]
    f=A(x)
    #print_gpu_memory()
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (
            f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x
            )/torch.linalg.norm(f, dim=-1).unsqueeze(-1), steps

import torch.nn as nn

def add(x,y):
    return x+y

class square(nn.Module):
    def __init__(self, g):
        super(square, self).__init__()
        self.g=g
    
    def forward(self, x):
        return self.g(self.g(x))

class minus_f(nn.Module):
    def __init__(self, g, f):
        super(minus_f, self).__init__()
        self.g=g
        self.f=f

    def forward(self, x):
        return self.g(x) - self.f*x

class t_minus(nn.Module):
    def __init__(self, g, t):
        super(t_minus, self).__init__()
        self.g=g
        self.t=t

    def forward(self, x):
        return self.t*x - self.g(x)


class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=((torch.rand(batch_size,1).cuda())*torch.max(torch.cat([self.f**2, (1-self.f)**2], dim=1), dim=1, keepdim = True)[0]).cuda()
        self.z=torch.max(torch.cat([self.t**2, (self.t-self.f**2)**2, (self.t-(1-self.f)**2)**2], dim=1), dim=1, keepdim = True)[0]
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
        self.minus_f = minus_f(self.A, self.f)
        self.square_1 = square(self.minus_f)
        self.t_minus = t_minus(self.square_1, self.t)
        self.square_2 = square(self.t_minus)
    
    def forward_1(self,x):
        return x - self.square_2(x)/self.z

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=256
A=torch.diag(torch.tensor(range(matrix_size), dtype=torch.float64)/matrix_size)

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float64)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

probs = 0.5*torch.ones((batch_size, 1)).cuda()
delta = torch.tensor(1e-6).cuda()

class truncated(nn.Module):
    def __init__(self, f, x):
        super(truncated, self).__init__()
        
        y = f(x)
        xx = torch.sum(x**2, dim = -1, keepdim = True)
        xAx = torch.sum(x*y, dim = -1, keepdim = True)
        xAAx = torch.sum(y**2, dim = -1, keepdim = True)
        mean = xAx/xx
        sigma = torch.sqrt(xAAx*xx - xAx**2)/xx
        self.mean = mean + (-1)**torch.bernoulli(probs) * sigma
        self.sigma = torch.sqrt(-2*sigma*torch.log(delta))
        self.lambda_min = self.mean - self.sigma
        self.lambda_max = self.mean + self.sigma
        self.f = f

    def forward(self, x):
        y = self.f(x) - self.mean * x
        y = self.f(y) - self.mean * y
        return x - y / torch.min(torch.cat((
            (self.lambda_min - self.mean)**2, 
            (self.lambda_max - self.mean)**2), 
                dim = -1), dim = -1, keepdim = True)[0]

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
stepses=[]
with torch.cuda.device(0):
    for i in range(1):
        x = torch.empty((batch_size,matrix_size), dtype = torch.float64).normal_(mean=0,std=1).cuda()
        steps = torch.zeros((batch_size,1)).cuda()
        t[0] = x/torch.linalg.norm(x, dim=-1, keepdim=True)
        cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        while torch.any((steps<1000) & cond):
            g = modified(expr_1,f,batch_size)
            s = grad_ascend_lr(g,t[0],1e-6,steps,1000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = torch.where(cond, s[-1], steps)
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
            print(i, list(np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()>0)[0]), end='\r')
        result+= torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()
        print(i, list(np.where(result>0)[0]), end='\r')
print('\n'+np.sum(result))

t


  Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2)


0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 22

TypeError: can only concatenate str (not "numpy.float64") to str

In [24]:
print(np.sum(result))

4647.0


In [28]:
print(np.sum(np.where(result>0, 1, 0)))

254


In [29]:
print(np.where(np.where(result>0, 0, 1)>0))

(array([160, 172]),)


In [25]:
np.min(steps.cpu().numpy()[np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=1,keepdim=True).cpu().numpy()>0)])

506.0

In [27]:
np.sort(steps.cpu().numpy()[np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=1,keepdim=True).cpu().numpy()>0)])[:20]

array([ 506.,  624.,  779.,  781.,  783.,  826.,  866.,  880.,  887.,
        922.,  947.,  961.,  965.,  966.,  981.,  983.,  989.,  990.,
        999., 1000.], dtype=float32)

In [31]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians
    #print("z")
    #print(z)
    #print("r")
    #print(r)
    #print("theta")
    #print(theta)

    # Calculate the cubic root of the magnitude and the angle divided by 3
    r = r ** (1/3)  # Magnitude of the cubic root
    theta = theta / 3  # Angle of the cubic root
    #print("root_magnitude")
    #print(r)
    #print("root_angle")
    #print(theta)

    return r * torch.cos(theta) + 1j * r * torch.sin(theta)  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e = e/a
    d = d/a
    c = c/a
    b = b/a
    i, lr_1 = torch.frexp(e)
    i, lr_2 = torch.frexp(d)
    i, lr_3 = torch.frexp(c)
    i, lr_4 = torch.frexp(b)
    i = torch.max(
        torch.cat((2 + 3 * torch.max(torch.cat((2 * lr_3, 2 + lr_4 + lr_2, 4 + lr_1), 
                                           dim = -1), dim = -1, keepdim = True)[0], 
                  2 * torch.max(torch.cat((1 + 3 * lr_3, 3 + lr_4 + lr_3 + lr_2, 5 + 2 * lr_4 + lr_1, 5 + 2 * lr_2, 8 + lr_3 + lr_1), 
                                        dim = -1), dim = -1, keepdim = True)[0]), 
                  dim = -1), dim = -1, keepdim = True)[0]//12
        
    a = c ** 2 - 3 * b * d + 12 * e
    e = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e
    e = torch.ldexp(e, -6*i)
    a = torch.ldexp(a, -4*i)
    #print("Delta_0")
    #print(a)
    #print("Delta_1")
    #print(e)
    #print("under sqrt")
    #print(e ** 2 - 4 * a ** 3)
    #print("sqrt")
    #print(torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64)))

    Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2) 
    #print("Q")
    #print(Q)
    
    # Calculate p and q
    e = torch.ldexp(c - 0.375 * b ** 2, -2*i)
    d = torch.ldexp((0.5*b) ** 3 - 0.5 * b * c + d, -3*i)
    #print("q")
    #print(d)
    S = torch.sqrt(-2 / 3 * e + (Q + a / Q) / 3) / 2
    a = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e + d / S)
    c = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e - d / S)
    b = torch.ldexp(-b, -i-2)

    lr_1 = torch.ldexp(torch.real(b - S + a), i)
    lr_2 = torch.ldexp(torch.real(b - S - a), i)
    lr_3 = torch.ldexp(torch.real(b + S + c), i)
    lr_4 = torch.ldexp(torch.real(b + S - c), i)
    
    return torch.cat([lr_1, lr_2, lr_3, lr_4],dim = -1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    a = c - b ** 2 / 3
    d = 2/27*b**3 - b*c/3 + d
    c = (a/3)**3 + (d/2)**2
    c = torch.sqrt(c)
    alpha = -d/2 + c
    beta = -d/2 - c
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta) + 1j * sqrt_3 * (alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta) - 1j * sqrt_3 * (alpha-beta))/2)
    return torch.cat([lr1, lr2, lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1, x2, torch.zeros_like(x1), torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.sum(x*x, dim = -1, keepdim = True)
    a_1 = torch.sum(u*x, dim = -1, keepdim = True)
    a_2 = torch.sum(v*x, dim = -1, keepdim = True)
    a_3 = torch.sum(w*x, dim = -1, keepdim = True)
    a_4 = torch.sum(w*u, dim = -1, keepdim = True)
    a_5 = torch.sum(w*v, dim = -1, keepdim = True)
    a_6 = torch.sum(w*w, dim = -1, keepdim = True)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    #p_0 = a_0
    #p_1 = a_1
    #p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = a_0 * r_1 * r_2 - 2 * a_1 * r_0 * r_2 + a_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * a_0 * r_1 * q_2 - 3 * r_0 * a_1 * q_2
    d = 2 * a_0 * q_1 * q_2 + 2 * a_0 * r_1 * a_2 - a_0 * a_1 * r_2 - r_0 * a_1 * a_2
    e = 2 * a_0 * q_1 * a_2 - a_0 * a_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), lr * (-x / a_0 + 2 * u / a_1 - v / a_2), n, lr
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x = x + v
    return x/torch.linalg.norm(x, dim = -1, keepdim = True)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A, x)
    x = update_vector(x, v)
    steps = steps_already + 1
    cond = (1 - e > threshold) & (steps < steps_max)
    while cond.any():
        steps = torch.where(cond, steps + 1, steps)
        e,v,i,l = optimal_lr(A, x)
        x = torch.where(cond, update_vector(x, v), x)
        cond = (1 - e > threshold) & (steps < steps_max)
    v=[]
    f=A(x)
    #print_gpu_memory()
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (
            f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x
            )/torch.linalg.norm(f, dim=-1).unsqueeze(-1), steps

import torch.nn as nn

def add(x,y):
    return x+y

class square(nn.Module):
    def __init__(self, g):
        super(square, self).__init__()
        self.g=g
    
    def forward(self, x):
        return self.g(self.g(x))

class minus_f(nn.Module):
    def __init__(self, g, f):
        super(minus_f, self).__init__()
        self.g=g
        self.f=f

    def forward(self, x):
        return self.g(x) - self.f*x

class t_minus(nn.Module):
    def __init__(self, g, t):
        super(t_minus, self).__init__()
        self.g=g
        self.t=t

    def forward(self, x):
        return self.t*x - self.g(x)


class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=((torch.rand(batch_size,1).cuda())*torch.max(torch.cat([self.f**2, (1-self.f)**2], dim=1), dim=1, keepdim = True)[0]).cuda()
        self.z=torch.max(torch.cat([self.t**2, (self.t-self.f**2)**2, (self.t-(1-self.f)**2)**2], dim=1), dim=1, keepdim = True)[0]
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
        self.minus_f = minus_f(self.A, self.f)
        self.square_1 = square(self.minus_f)
        self.t_minus = t_minus(self.square_1, self.t)
        self.square_2 = square(self.t_minus)
    
    def forward_1(self,x):
        return x - self.square_2(x)/self.z

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=128
A=torch.diag(1/(1+torch.tensor(range(matrix_size), dtype=torch.float64)))

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float64)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

probs = 0.5*torch.ones((batch_size, 1)).cuda()
delta = torch.tensor(1e-6).cuda()

class truncated(nn.Module):
    def __init__(self, f, x):
        super(truncated, self).__init__()
        
        y = f(x)
        xx = torch.sum(x**2, dim = -1, keepdim = True)
        xAx = torch.sum(x*y, dim = -1, keepdim = True)
        xAAx = torch.sum(y**2, dim = -1, keepdim = True)
        mean = xAx/xx
        sigma = torch.sqrt(xAAx*xx - xAx**2)/xx
        self.mean = mean + (-1)**torch.bernoulli(probs) * sigma
        self.sigma = torch.sqrt(-2*sigma*torch.log(delta))
        self.lambda_min = self.mean - self.sigma
        self.lambda_max = self.mean + self.sigma
        self.f = f

    def forward(self, x):
        y = self.f(x) - self.mean * x
        y = self.f(y) - self.mean * y
        return x - y / torch.min(torch.cat((
            (self.lambda_min - self.mean)**2, 
            (self.lambda_max - self.mean)**2), 
                dim = -1), dim = -1, keepdim = True)[0]

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
stepses=[]
with torch.cuda.device(0):
    for i in range(1):
        x = torch.empty((batch_size,matrix_size), dtype = torch.float64).normal_(mean=0,std=1).cuda()
        steps = torch.zeros((batch_size,1)).cuda()
        t[0] = x/torch.linalg.norm(x, dim=-1, keepdim=True)
        cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        while torch.any((steps<1000) & cond):
            g = modified(expr_1,f,batch_size)
            s = grad_ascend_lr(g,t[0],1e-6,steps,1000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = torch.where(cond, s[-1], steps)
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
            print(i, list(np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()>0)[0]), end='\r')
        result+= torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()
        print(i, list(np.where(result>0)[0]), end='\r')
print('\n'+np.sum(result))

t


  Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2)


0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 25, 26, 28, 29, 30, 32, 34, 35, 37, 38, 39, 41, 44, 45, 49, 60]

TypeError: can only concatenate str (not "numpy.float64") to str

In [35]:
print(result)

[4.824e+03 2.549e+03 1.187e+03 1.052e+03 9.570e+02 7.230e+02 5.510e+02
 4.090e+02 2.310e+02 1.750e+02 1.230e+02 7.700e+01 5.900e+01 2.900e+01
 2.500e+01 1.500e+01 1.200e+01 1.300e+01 5.000e+00 1.000e+00 5.000e+00
 5.000e+00 2.000e+00 2.000e+00 0.000e+00 4.000e+00 2.000e+00 0.000e+00
 1.000e+00 2.000e+00 4.000e+00 0.000e+00 1.000e+00 0.000e+00 2.000e+00
 1.000e+00 0.000e+00 2.000e+00 1.000e+00 3.000e+00 0.000e+00 2.000e+00
 0.000e+00 0.000e+00 1.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00
 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000

In [34]:
np.sort(steps.cpu().numpy()[np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=1,keepdim=True).cpu().numpy()>0)])[:2000]

array([1., 1., 1., ..., 3., 3., 3.], dtype=float32)

In [36]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians
    #print("z")
    #print(z)
    #print("r")
    #print(r)
    #print("theta")
    #print(theta)

    # Calculate the cubic root of the magnitude and the angle divided by 3
    r = r ** (1/3)  # Magnitude of the cubic root
    theta = theta / 3  # Angle of the cubic root
    #print("root_magnitude")
    #print(r)
    #print("root_angle")
    #print(theta)

    return r * torch.cos(theta) + 1j * r * torch.sin(theta)  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e = e/a
    d = d/a
    c = c/a
    b = b/a
    i, lr_1 = torch.frexp(e)
    i, lr_2 = torch.frexp(d)
    i, lr_3 = torch.frexp(c)
    i, lr_4 = torch.frexp(b)
    i = torch.max(
        torch.cat((2 + 3 * torch.max(torch.cat((2 * lr_3, 2 + lr_4 + lr_2, 4 + lr_1), 
                                           dim = -1), dim = -1, keepdim = True)[0], 
                  2 * torch.max(torch.cat((1 + 3 * lr_3, 3 + lr_4 + lr_3 + lr_2, 5 + 2 * lr_4 + lr_1, 5 + 2 * lr_2, 8 + lr_3 + lr_1), 
                                        dim = -1), dim = -1, keepdim = True)[0]), 
                  dim = -1), dim = -1, keepdim = True)[0]//12
        
    a = c ** 2 - 3 * b * d + 12 * e
    e = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e
    e = torch.ldexp(e, -6*i)
    a = torch.ldexp(a, -4*i)
    #print("Delta_0")
    #print(a)
    #print("Delta_1")
    #print(e)
    #print("under sqrt")
    #print(e ** 2 - 4 * a ** 3)
    #print("sqrt")
    #print(torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64)))

    Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2) 
    #print("Q")
    #print(Q)
    
    # Calculate p and q
    e = torch.ldexp(c - 0.375 * b ** 2, -2*i)
    d = torch.ldexp((0.5*b) ** 3 - 0.5 * b * c + d, -3*i)
    #print("q")
    #print(d)
    S = torch.sqrt(-2 / 3 * e + (Q + a / Q) / 3) / 2
    a = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e + d / S)
    c = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e - d / S)
    b = torch.ldexp(-b, -i-2)

    lr_1 = torch.ldexp(torch.real(b - S + a), i)
    lr_2 = torch.ldexp(torch.real(b - S - a), i)
    lr_3 = torch.ldexp(torch.real(b + S + c), i)
    lr_4 = torch.ldexp(torch.real(b + S - c), i)
    
    return torch.cat([lr_1, lr_2, lr_3, lr_4],dim = -1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    a = c - b ** 2 / 3
    d = 2/27*b**3 - b*c/3 + d
    c = (a/3)**3 + (d/2)**2
    c = torch.sqrt(c)
    alpha = -d/2 + c
    beta = -d/2 - c
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta) + 1j * sqrt_3 * (alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta) - 1j * sqrt_3 * (alpha-beta))/2)
    return torch.cat([lr1, lr2, lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1, x2, torch.zeros_like(x1), torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.sum(x*x, dim = -1, keepdim = True)
    a_1 = torch.sum(u*x, dim = -1, keepdim = True)
    a_2 = torch.sum(v*x, dim = -1, keepdim = True)
    a_3 = torch.sum(w*x, dim = -1, keepdim = True)
    a_4 = torch.sum(w*u, dim = -1, keepdim = True)
    a_5 = torch.sum(w*v, dim = -1, keepdim = True)
    a_6 = torch.sum(w*w, dim = -1, keepdim = True)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    #p_0 = a_0
    #p_1 = a_1
    #p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = a_0 * r_1 * r_2 - 2 * a_1 * r_0 * r_2 + a_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * a_0 * r_1 * q_2 - 3 * r_0 * a_1 * q_2
    d = 2 * a_0 * q_1 * q_2 + 2 * a_0 * r_1 * a_2 - a_0 * a_1 * r_2 - r_0 * a_1 * a_2
    e = 2 * a_0 * q_1 * a_2 - a_0 * a_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), lr * (-x / a_0 + 2 * u / a_1 - v / a_2), n, lr
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x = x + v
    return x/torch.linalg.norm(x, dim = -1, keepdim = True)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A, x)
    x = update_vector(x, v)
    steps = steps_already + 1
    cond = (1 - e > threshold) & (steps < steps_max)
    while cond.any():
        steps = torch.where(cond, steps + 1, steps)
        e,v,i,l = optimal_lr(A, x)
        x = torch.where(cond, update_vector(x, v), x)
        cond = (1 - e > threshold) & (steps < steps_max)
    v=[]
    f=A(x)
    #print_gpu_memory()
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (
            f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x
            )/torch.linalg.norm(f, dim=-1).unsqueeze(-1), steps

import torch.nn as nn

def add(x,y):
    return x+y

class square(nn.Module):
    def __init__(self, g):
        super(square, self).__init__()
        self.g=g
    
    def forward(self, x):
        return self.g(self.g(x))

class minus_f(nn.Module):
    def __init__(self, g, f):
        super(minus_f, self).__init__()
        self.g=g
        self.f=f

    def forward(self, x):
        return self.g(x) - self.f*x

class t_minus(nn.Module):
    def __init__(self, g, t):
        super(t_minus, self).__init__()
        self.g=g
        self.t=t

    def forward(self, x):
        return self.t*x - self.g(x)


class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=((torch.rand(batch_size,1).cuda())*torch.max(torch.cat([self.f**2, (1-self.f)**2], dim=1), dim=1, keepdim = True)[0]).cuda()
        self.z=torch.max(torch.cat([self.t**2, (self.t-self.f**2)**2, (self.t-(1-self.f)**2)**2], dim=1), dim=1, keepdim = True)[0]
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
        self.minus_f = minus_f(self.A, self.f)
        self.square_1 = square(self.minus_f)
        self.t_minus = t_minus(self.square_1, self.t)
        self.square_2 = square(self.t_minus)
    
    def forward_1(self,x):
        return x - self.square_2(x)/self.z

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=128
A=torch.diag(1/(1+torch.tensor(range(matrix_size), dtype=torch.float64)))

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float64)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

probs = 0.5*torch.ones((batch_size, 1)).cuda()
delta = torch.tensor(1e-6).cuda()

class truncated(nn.Module):
    def __init__(self, f, x):
        super(truncated, self).__init__()
        
        y = f(x)
        xx = torch.sum(x**2, dim = -1, keepdim = True)
        xAx = torch.sum(x*y, dim = -1, keepdim = True)
        xAAx = torch.sum(y**2, dim = -1, keepdim = True)
        mean = xAx/xx
        sigma = torch.sqrt(xAAx*xx - xAx**2)/xx
        self.mean = mean + (-1)**torch.bernoulli(probs) * sigma
        self.sigma = torch.sqrt(-2*sigma*torch.log(delta))
        self.lambda_min = self.mean - self.sigma
        self.lambda_max = self.mean + self.sigma
        self.f = f

    def forward(self, x):
        y = self.f(x) - self.mean * x
        y = self.f(y) - self.mean * y
        return x - y / torch.min(torch.cat((
            (self.lambda_min - self.mean)**2, 
            (self.lambda_max - self.mean)**2), 
                dim = -1), dim = -1, keepdim = True)[0]

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
stepses=[]
with torch.cuda.device(0):
    for i in range(1):
        x = torch.empty((batch_size, matrix_size), dtype=torch.float64).normal_(mean = 0, std = 1).cuda()
        steps = torch.zeros((batch_size, 1)).cuda()
        stepses.append(torch.zeros((batch_size, 1)).cuda())
        t[0] = x/torch.linalg.norm(x, dim = -1, keepdim = True)
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = modified(expr_1, f, batch_size)
        s = grad_ascend_lr(g, t[0], 1e-6, steps, 1000)
        t[0] = torch.where(cond, s[0], t[0])
        steps = s[-1]
        stepses[i] = torch.where(cond, s[-1], stepses[i])
        s = [1]
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = truncated(f, t[0])
        while torch.any((steps < 1000) & cond):
            s = grad_ascend_lr(g, t[0], 1e-10, steps, 1000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            stepses[i] = torch.where(cond, s[-1], stepses[i])
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
            g = truncated(f, t[0])
        result+= torch.sum(torch.where(torch.abs(h(t[0])) >= 0.9, 1, 0), dim = 0).cpu().numpy()
        print(i, list(np.where(result > 0)[0]), np.sum(result) / ((i + 1) * batch_size), end='\r')


t


  Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2)


0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 30, 31, 32, 33, 34, 35, 37, 38, 40, 41, 43, 49, 55, 57] 0.6531982421875

In [37]:
np.sort(steps.cpu().numpy()[np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=1,keepdim=True).cpu().numpy()>0)])[:2000]

array([ 14.,  16.,  17., ..., 103., 103., 103.], dtype=float32)

In [38]:
print(result)

[4.711e+03 1.675e+03 8.690e+02 8.170e+02 6.890e+02 5.580e+02 4.200e+02
 3.280e+02 2.020e+02 1.040e+02 8.800e+01 5.900e+01 4.400e+01 2.700e+01
 2.100e+01 1.900e+01 1.400e+01 4.000e+00 9.000e+00 3.000e+00 3.000e+00
 3.000e+00 1.000e+00 2.000e+00 5.000e+00 6.000e+00 0.000e+00 2.000e+00
 1.000e+00 0.000e+00 1.000e+00 2.000e+00 3.000e+00 1.000e+00 1.000e+00
 1.000e+00 0.000e+00 1.000e+00 2.000e+00 0.000e+00 1.000e+00 1.000e+00
 0.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 1.000e+00
 0.000e+00 1.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00 0.000e+00
 0.000

In [41]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians
    #print("z")
    #print(z)
    #print("r")
    #print(r)
    #print("theta")
    #print(theta)

    # Calculate the cubic root of the magnitude and the angle divided by 3
    r = r ** (1/3)  # Magnitude of the cubic root
    theta = theta / 3  # Angle of the cubic root
    #print("root_magnitude")
    #print(r)
    #print("root_angle")
    #print(theta)

    return r * torch.cos(theta) + 1j * r * torch.sin(theta)  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e = e/a
    d = d/a
    c = c/a
    b = b/a
    i, lr_1 = torch.frexp(e)
    i, lr_2 = torch.frexp(d)
    i, lr_3 = torch.frexp(c)
    i, lr_4 = torch.frexp(b)
    i = torch.max(
        torch.cat((2 + 3 * torch.max(torch.cat((2 * lr_3, 2 + lr_4 + lr_2, 4 + lr_1), 
                                           dim = -1), dim = -1, keepdim = True)[0], 
                  2 * torch.max(torch.cat((1 + 3 * lr_3, 3 + lr_4 + lr_3 + lr_2, 5 + 2 * lr_4 + lr_1, 5 + 2 * lr_2, 8 + lr_3 + lr_1), 
                                        dim = -1), dim = -1, keepdim = True)[0]), 
                  dim = -1), dim = -1, keepdim = True)[0]//12
        
    a = c ** 2 - 3 * b * d + 12 * e
    e = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e
    e = torch.ldexp(e, -6*i)
    a = torch.ldexp(a, -4*i)
    #print("Delta_0")
    #print(a)
    #print("Delta_1")
    #print(e)
    #print("under sqrt")
    #print(e ** 2 - 4 * a ** 3)
    #print("sqrt")
    #print(torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64)))

    Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2) 
    #print("Q")
    #print(Q)
    
    # Calculate p and q
    e = torch.ldexp(c - 0.375 * b ** 2, -2*i)
    d = torch.ldexp((0.5*b) ** 3 - 0.5 * b * c + d, -3*i)
    #print("q")
    #print(d)
    S = torch.sqrt(-2 / 3 * e + (Q + a / Q) / 3) / 2
    a = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e + d / S)
    c = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e - d / S)
    b = torch.ldexp(-b, -i-2)

    lr_1 = torch.ldexp(torch.real(b - S + a), i)
    lr_2 = torch.ldexp(torch.real(b - S - a), i)
    lr_3 = torch.ldexp(torch.real(b + S + c), i)
    lr_4 = torch.ldexp(torch.real(b + S - c), i)
    
    return torch.cat([lr_1, lr_2, lr_3, lr_4],dim = -1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    a = c - b ** 2 / 3
    d = 2/27*b**3 - b*c/3 + d
    c = (a/3)**3 + (d/2)**2
    c = torch.sqrt(c)
    alpha = -d/2 + c
    beta = -d/2 - c
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta) + 1j * sqrt_3 * (alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta) - 1j * sqrt_3 * (alpha-beta))/2)
    return torch.cat([lr1, lr2, lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1, x2, torch.zeros_like(x1), torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.sum(x*x, dim = -1, keepdim = True)
    a_1 = torch.sum(u*x, dim = -1, keepdim = True)
    a_2 = torch.sum(v*x, dim = -1, keepdim = True)
    a_3 = torch.sum(w*x, dim = -1, keepdim = True)
    a_4 = torch.sum(w*u, dim = -1, keepdim = True)
    a_5 = torch.sum(w*v, dim = -1, keepdim = True)
    a_6 = torch.sum(w*w, dim = -1, keepdim = True)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    #p_0 = a_0
    #p_1 = a_1
    #p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = a_0 * r_1 * r_2 - 2 * a_1 * r_0 * r_2 + a_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * a_0 * r_1 * q_2 - 3 * r_0 * a_1 * q_2
    d = 2 * a_0 * q_1 * q_2 + 2 * a_0 * r_1 * a_2 - a_0 * a_1 * r_2 - r_0 * a_1 * a_2
    e = 2 * a_0 * q_1 * a_2 - a_0 * a_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), lr * (-x / a_0 + 2 * u / a_1 - v / a_2), n, lr
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x = x + v
    return x/torch.linalg.norm(x, dim = -1, keepdim = True)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A, x)
    x = update_vector(x, v)
    steps = steps_already + 1
    cond = (1 - e > threshold) & (steps < steps_max)
    while cond.any():
        steps = torch.where(cond, steps + 1, steps)
        e,v,i,l = optimal_lr(A, x)
        x = torch.where(cond, update_vector(x, v), x)
        cond = (1 - e > threshold) & (steps < steps_max)
    v=[]
    f=A(x)
    #print_gpu_memory()
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (
            f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x
            )/torch.linalg.norm(f, dim=-1).unsqueeze(-1), steps

import torch.nn as nn

def add(x,y):
    return x+y

class square(nn.Module):
    def __init__(self, g):
        super(square, self).__init__()
        self.g=g
    
    def forward(self, x):
        return self.g(self.g(x))

class minus_f(nn.Module):
    def __init__(self, g, f):
        super(minus_f, self).__init__()
        self.g=g
        self.f=f

    def forward(self, x):
        return self.g(x) - self.f*x

class t_minus(nn.Module):
    def __init__(self, g, t):
        super(t_minus, self).__init__()
        self.g=g
        self.t=t

    def forward(self, x):
        return self.t*x - self.g(x)


class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=((torch.rand(batch_size,1).cuda())*torch.max(torch.cat([self.f**2, (1-self.f)**2], dim=1), dim=1, keepdim = True)[0]).cuda()
        self.z=torch.max(torch.cat([self.t**2, (self.t-self.f**2)**2, (self.t-(1-self.f)**2)**2], dim=1), dim=1, keepdim = True)[0]
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
        self.minus_f = minus_f(self.A, self.f)
        self.square_1 = square(self.minus_f)
        self.t_minus = t_minus(self.square_1, self.t)
        self.square_2 = square(self.t_minus)
    
    def forward_1(self,x):
        return x - self.square_2(x)/self.z

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=128
A=torch.diag(1/(1+torch.tensor(range(matrix_size), dtype=torch.float64)))

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float64)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

probs = 0.5*torch.ones((batch_size, 1)).cuda()
delta = torch.tensor(1e-6).cuda()

class truncated(nn.Module):
    def __init__(self, f, x):
        super(truncated, self).__init__()
        
        y = f(x)
        xx = torch.sum(x**2, dim = -1, keepdim = True)
        xAx = torch.sum(x*y, dim = -1, keepdim = True)
        xAAx = torch.sum(y**2, dim = -1, keepdim = True)
        mean = xAx/xx
        sigma = torch.sqrt(xAAx*xx - xAx**2)/xx
        self.mean = mean + (-1)**torch.bernoulli(probs) * sigma
        self.sigma = torch.sqrt(-2*sigma*torch.log(delta))
        self.lambda_min = self.mean - self.sigma
        self.lambda_max = self.mean + self.sigma
        self.f = f

    def forward(self, x):
        y = self.f(x) - self.mean * x
        y = self.f(y) - self.mean * y
        return x - y / torch.min(torch.cat((
            (self.lambda_min - self.mean)**2, 
            (self.lambda_max - self.mean)**2), 
                dim = -1), dim = -1, keepdim = True)[0]

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
stepses=[]
with torch.cuda.device(0):
    for i in range(1):
        x = torch.empty((batch_size, matrix_size), dtype=torch.float64).normal_(mean = 0, std = 1).cuda()
        steps = torch.zeros((batch_size, 1)).cuda()
        stepses.append(torch.zeros((batch_size, 1)).cuda())
        t[0] = x/torch.linalg.norm(x, dim = -1, keepdim = True)
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = modified(expr_1, f, batch_size)
        s = grad_ascend_lr(g, t[0], 1e-30, steps, 2000)
        t[0] = torch.where(cond, s[0], t[0])
        steps = s[-1]
        stepses[i] = torch.where(cond, s[-1], stepses[i])
        s = [1]
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = truncated(f, t[0])
        while torch.any((steps < 2000) & cond):
            s = grad_ascend_lr(g, t[0], 1e-30, steps, 2000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            stepses[i] = torch.where(cond, s[-1], stepses[i])
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
            g = truncated(f, t[0])
        result+= torch.sum(torch.where(torch.abs(h(t[0])) >= 0.9, 1, 0), dim = 0).cpu().numpy()
        print(i, list(np.where(result > 0)[0]), np.sum(result) / ((i + 1) * batch_size), end='\r')


t


  Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2)


0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 56, 57, 58, 60, 61, 63, 70, 71, 75, 127] 0.62646484375

In [43]:
np.sort(steps.cpu().numpy()[np.where(torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=1,keepdim=True).cpu().numpy()>0)])[:1000]

array([   7.,    7.,    7.,    7.,    8.,    8.,    8.,    8.,    8.,
          8.,    8.,    8.,    8.,    8.,    8.,    8.,    9.,    9.,
          9.,    9.,    9.,    9.,    9.,   10.,   10.,   10.,   10.,
         10.,   10.,   10.,   10.,   10.,   10.,   10.,   11.,   11.,
         11.,   11.,   11.,   11.,   11.,   11.,   11.,   12.,   12.,
         12.,   12.,   12.,   12.,   12.,   12.,   12.,   12.,   12.,
         12.,   12.,   13.,   13.,   13.,   13.,   13.,   13.,   13.,
         13.,   13.,   13.,   14.,   14.,   14.,   14.,   14.,   14.,
         14.,   14.,   14.,   15.,   15.,   15.,   15.,   15.,   15.,
         15.,   15.,   15.,   16.,   16.,   16.,   16.,   16.,   16.,
         16.,   16.,   16.,   16.,   17.,   17.,   17.,   17.,   17.,
         17.,   17.,   17.,   18.,   18.,   18.,   18.,   18.,   18.,
         18.,   19.,   19.,   19.,   19.,   19.,   19.,   19.,   20.,
         20.,   20.,   20.,   20.,   20.,   20.,   21.,   21.,   21.,
         21.,   21.,

In [44]:
np.sum(result)

10264.0

In [45]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians
    #print("z")
    #print(z)
    #print("r")
    #print(r)
    #print("theta")
    #print(theta)

    # Calculate the cubic root of the magnitude and the angle divided by 3
    r = r ** (1/3)  # Magnitude of the cubic root
    theta = theta / 3  # Angle of the cubic root
    #print("root_magnitude")
    #print(r)
    #print("root_angle")
    #print(theta)

    return r * torch.cos(theta) + 1j * r * torch.sin(theta)  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e = e/a
    d = d/a
    c = c/a
    b = b/a
    i, lr_1 = torch.frexp(e)
    i, lr_2 = torch.frexp(d)
    i, lr_3 = torch.frexp(c)
    i, lr_4 = torch.frexp(b)
    i = torch.max(
        torch.cat((2 + 3 * torch.max(torch.cat((2 * lr_3, 2 + lr_4 + lr_2, 4 + lr_1), 
                                           dim = -1), dim = -1, keepdim = True)[0], 
                  2 * torch.max(torch.cat((1 + 3 * lr_3, 3 + lr_4 + lr_3 + lr_2, 5 + 2 * lr_4 + lr_1, 5 + 2 * lr_2, 8 + lr_3 + lr_1), 
                                        dim = -1), dim = -1, keepdim = True)[0]), 
                  dim = -1), dim = -1, keepdim = True)[0]//12
        
    a = c ** 2 - 3 * b * d + 12 * e
    e = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e
    e = torch.ldexp(e, -6*i)
    a = torch.ldexp(a, -4*i)
    #print("Delta_0")
    #print(a)
    #print("Delta_1")
    #print(e)
    #print("under sqrt")
    #print(e ** 2 - 4 * a ** 3)
    #print("sqrt")
    #print(torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64)))

    Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2) 
    #print("Q")
    #print(Q)
    
    # Calculate p and q
    e = torch.ldexp(c - 0.375 * b ** 2, -2*i)
    d = torch.ldexp((0.5*b) ** 3 - 0.5 * b * c + d, -3*i)
    #print("q")
    #print(d)
    S = torch.sqrt(-2 / 3 * e + (Q + a / Q) / 3) / 2
    a = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e + d / S)
    c = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e - d / S)
    b = torch.ldexp(-b, -i-2)

    lr_1 = torch.ldexp(torch.real(b - S + a), i)
    lr_2 = torch.ldexp(torch.real(b - S - a), i)
    lr_3 = torch.ldexp(torch.real(b + S + c), i)
    lr_4 = torch.ldexp(torch.real(b + S - c), i)
    
    return torch.cat([lr_1, lr_2, lr_3, lr_4],dim = -1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    a = c - b ** 2 / 3
    d = 2/27*b**3 - b*c/3 + d
    c = (a/3)**3 + (d/2)**2
    c = torch.sqrt(c)
    alpha = -d/2 + c
    beta = -d/2 - c
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta) + 1j * sqrt_3 * (alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta) - 1j * sqrt_3 * (alpha-beta))/2)
    return torch.cat([lr1, lr2, lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1, x2, torch.zeros_like(x1), torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.sum(x*x, dim = -1, keepdim = True)
    a_1 = torch.sum(u*x, dim = -1, keepdim = True)
    a_2 = torch.sum(v*x, dim = -1, keepdim = True)
    a_3 = torch.sum(w*x, dim = -1, keepdim = True)
    a_4 = torch.sum(w*u, dim = -1, keepdim = True)
    a_5 = torch.sum(w*v, dim = -1, keepdim = True)
    a_6 = torch.sum(w*w, dim = -1, keepdim = True)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    #p_0 = a_0
    #p_1 = a_1
    #p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = a_0 * r_1 * r_2 - 2 * a_1 * r_0 * r_2 + a_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * a_0 * r_1 * q_2 - 3 * r_0 * a_1 * q_2
    d = 2 * a_0 * q_1 * q_2 + 2 * a_0 * r_1 * a_2 - a_0 * a_1 * r_2 - r_0 * a_1 * a_2
    e = 2 * a_0 * q_1 * a_2 - a_0 * a_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), lr * (-x / a_0 + 2 * u / a_1 - v / a_2), n, lr
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x = x + v
    return x/torch.linalg.norm(x, dim = -1, keepdim = True)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A, x)
    x = update_vector(x, v)
    steps = steps_already + 1
    cond = (1 - e > threshold) & (steps < steps_max)
    while cond.any():
        steps = torch.where(cond, steps + 1, steps)
        e,v,i,l = optimal_lr(A, x)
        x = torch.where(cond, update_vector(x, v), x)
        cond = (1 - e > threshold) & (steps < steps_max)
    v=[]
    f=A(x)
    #print_gpu_memory()
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (
            f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x
            )/torch.linalg.norm(f, dim=-1).unsqueeze(-1), steps

import torch.nn as nn

def add(x,y):
    return x+y

class square(nn.Module):
    def __init__(self, g):
        super(square, self).__init__()
        self.g=g
    
    def forward(self, x):
        return self.g(self.g(x))

class minus_f(nn.Module):
    def __init__(self, g, f):
        super(minus_f, self).__init__()
        self.g=g
        self.f=f

    def forward(self, x):
        return self.g(x) - self.f*x

class t_minus(nn.Module):
    def __init__(self, g, t):
        super(t_minus, self).__init__()
        self.g=g
        self.t=t

    def forward(self, x):
        return self.t*x - self.g(x)


class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=((torch.rand(batch_size,1).cuda())*torch.max(torch.cat([self.f**2, (1-self.f)**2], dim=1), dim=1, keepdim = True)[0]).cuda()
        self.z=torch.max(torch.cat([self.t**2, (self.t-self.f**2)**2, (self.t-(1-self.f)**2)**2], dim=1), dim=1, keepdim = True)[0]
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
        self.minus_f = minus_f(self.A, self.f)
        self.square_1 = square(self.minus_f)
        self.t_minus = t_minus(self.square_1, self.t)
        self.square_2 = square(self.t_minus)
    
    def forward_1(self,x):
        return x - self.square_2(x)/self.z

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=256
A=torch.diag(1/(1+torch.tensor(range(matrix_size), dtype=torch.float64)))

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float64)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

probs = 0.5*torch.ones((batch_size, 1)).cuda()
delta = torch.tensor(1e-6).cuda()

class truncated(nn.Module):
    def __init__(self, f, x):
        super(truncated, self).__init__()
        
        y = f(x)
        xx = torch.sum(x**2, dim = -1, keepdim = True)
        xAx = torch.sum(x*y, dim = -1, keepdim = True)
        xAAx = torch.sum(y**2, dim = -1, keepdim = True)
        mean = xAx/xx
        sigma = torch.sqrt(xAAx*xx - xAx**2)/xx
        self.mean = mean + (-1)**torch.bernoulli(probs) * sigma
        self.sigma = torch.sqrt(-2*sigma*torch.log(delta))
        self.lambda_min = self.mean - self.sigma
        self.lambda_max = self.mean + self.sigma
        self.f = f

    def forward(self, x):
        y = self.f(x) - self.mean * x
        y = self.f(y) - self.mean * y
        return x - y / torch.min(torch.cat((
            (self.lambda_min - self.mean)**2, 
            (self.lambda_max - self.mean)**2), 
                dim = -1), dim = -1, keepdim = True)[0]

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
stepses=[]
with torch.cuda.device(0):
    for i in range(1):
        x = torch.empty((batch_size, matrix_size), dtype=torch.float64).normal_(mean = 0, std = 1).cuda()
        steps = torch.zeros((batch_size, 1)).cuda()
        stepses.append(torch.zeros((batch_size, 1)).cuda())
        t[0] = x/torch.linalg.norm(x, dim = -1, keepdim = True)
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = modified(expr_1, f, batch_size)
        s = grad_ascend_lr(g, t[0], 1e-30, steps, 2000)
        t[0] = torch.where(cond, s[0], t[0])
        steps = s[-1]
        stepses[i] = torch.where(cond, s[-1], stepses[i])
        s = [1]
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = truncated(f, t[0])
        while torch.any((steps < 2000) & cond):
            s = grad_ascend_lr(g, t[0], 1e-30, steps, 2000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            stepses[i] = torch.where(cond, s[-1], stepses[i])
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
            g = truncated(f, t[0])
        result+= torch.sum(torch.where(torch.abs(h(t[0])) >= 0.9, 1, 0), dim = 0).cpu().numpy()
        print(i, list(np.where(result > 0)[0]), np.sum(result) / ((i + 1) * batch_size), end='\r')


t


  Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2)


0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 255] 0.61627197265625

In [46]:
print(np.sum(result))

10097.0


In [47]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians
    #print("z")
    #print(z)
    #print("r")
    #print(r)
    #print("theta")
    #print(theta)

    # Calculate the cubic root of the magnitude and the angle divided by 3
    r = r ** (1/3)  # Magnitude of the cubic root
    theta = theta / 3  # Angle of the cubic root
    #print("root_magnitude")
    #print(r)
    #print("root_angle")
    #print(theta)

    return r * torch.cos(theta) + 1j * r * torch.sin(theta)  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e = e/a
    d = d/a
    c = c/a
    b = b/a
    i, lr_1 = torch.frexp(e)
    i, lr_2 = torch.frexp(d)
    i, lr_3 = torch.frexp(c)
    i, lr_4 = torch.frexp(b)
    i = torch.max(
        torch.cat((2 + 3 * torch.max(torch.cat((2 * lr_3, 2 + lr_4 + lr_2, 4 + lr_1), 
                                           dim = -1), dim = -1, keepdim = True)[0], 
                  2 * torch.max(torch.cat((1 + 3 * lr_3, 3 + lr_4 + lr_3 + lr_2, 5 + 2 * lr_4 + lr_1, 5 + 2 * lr_2, 8 + lr_3 + lr_1), 
                                        dim = -1), dim = -1, keepdim = True)[0]), 
                  dim = -1), dim = -1, keepdim = True)[0]//12
        
    a = c ** 2 - 3 * b * d + 12 * e
    e = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e
    e = torch.ldexp(e, -6*i)
    a = torch.ldexp(a, -4*i)
    #print("Delta_0")
    #print(a)
    #print("Delta_1")
    #print(e)
    #print("under sqrt")
    #print(e ** 2 - 4 * a ** 3)
    #print("sqrt")
    #print(torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64)))

    Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2) 
    #print("Q")
    #print(Q)
    
    # Calculate p and q
    e = torch.ldexp(c - 0.375 * b ** 2, -2*i)
    d = torch.ldexp((0.5*b) ** 3 - 0.5 * b * c + d, -3*i)
    #print("q")
    #print(d)
    S = torch.sqrt(-2 / 3 * e + (Q + a / Q) / 3) / 2
    a = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e + d / S)
    c = 0.5*torch.sqrt(-4 * S ** 2 - 2 * e - d / S)
    b = torch.ldexp(-b, -i-2)

    lr_1 = torch.ldexp(torch.real(b - S + a), i)
    lr_2 = torch.ldexp(torch.real(b - S - a), i)
    lr_3 = torch.ldexp(torch.real(b + S + c), i)
    lr_4 = torch.ldexp(torch.real(b + S - c), i)
    
    return torch.cat([lr_1, lr_2, lr_3, lr_4],dim = -1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    a = c - b ** 2 / 3
    d = 2/27*b**3 - b*c/3 + d
    c = (a/3)**3 + (d/2)**2
    c = torch.sqrt(c)
    alpha = -d/2 + c
    beta = -d/2 - c
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta) + 1j * sqrt_3 * (alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta) - 1j * sqrt_3 * (alpha-beta))/2)
    return torch.cat([lr1, lr2, lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1, x2, torch.zeros_like(x1), torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.sum(x*x, dim = -1, keepdim = True)
    a_1 = torch.sum(u*x, dim = -1, keepdim = True)
    a_2 = torch.sum(v*x, dim = -1, keepdim = True)
    a_3 = torch.sum(w*x, dim = -1, keepdim = True)
    a_4 = torch.sum(w*u, dim = -1, keepdim = True)
    a_5 = torch.sum(w*v, dim = -1, keepdim = True)
    a_6 = torch.sum(w*w, dim = -1, keepdim = True)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    #p_0 = a_0
    #p_1 = a_1
    #p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = a_0 * r_1 * r_2 - 2 * a_1 * r_0 * r_2 + a_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * a_0 * r_1 * q_2 - 3 * r_0 * a_1 * q_2
    d = 2 * a_0 * q_1 * q_2 + 2 * a_0 * r_1 * a_2 - a_0 * a_1 * r_2 - r_0 * a_1 * a_2
    e = 2 * a_0 * q_1 * a_2 - a_0 * a_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), lr * (-x / a_0 + 2 * u / a_1 - v / a_2), n, lr
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x = x + v
    return x/torch.linalg.norm(x, dim = -1, keepdim = True)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A, x)
    x = update_vector(x, v)
    steps = steps_already + 1
    cond = (1 - e > threshold) & (steps < steps_max)
    while cond.any():
        steps = torch.where(cond, steps + 1, steps)
        e,v,i,l = optimal_lr(A, x)
        x = torch.where(cond, update_vector(x, v), x)
        cond = (1 - e > threshold) & (steps < steps_max)
    v=[]
    f=A(x)
    #print_gpu_memory()
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (
            f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x
            )/torch.linalg.norm(f, dim=-1).unsqueeze(-1), steps

import torch.nn as nn

def add(x,y):
    return x+y

class square(nn.Module):
    def __init__(self, g):
        super(square, self).__init__()
        self.g=g
    
    def forward(self, x):
        return self.g(self.g(x))

class minus_f(nn.Module):
    def __init__(self, g, f):
        super(minus_f, self).__init__()
        self.g=g
        self.f=f

    def forward(self, x):
        return self.g(x) - self.f*x

class t_minus(nn.Module):
    def __init__(self, g, t):
        super(t_minus, self).__init__()
        self.g=g
        self.t=t

    def forward(self, x):
        return self.t*x - self.g(x)


class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=((torch.rand(batch_size,1).cuda())*torch.max(torch.cat([self.f**2, (1-self.f)**2], dim=1), dim=1, keepdim = True)[0]).cuda()
        self.z=torch.max(torch.cat([self.t**2, (self.t-self.f**2)**2, (self.t-(1-self.f)**2)**2], dim=1), dim=1, keepdim = True)[0]
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
        self.minus_f = minus_f(self.A, self.f)
        self.square_1 = square(self.minus_f)
        self.t_minus = t_minus(self.square_1, self.t)
        self.square_2 = square(self.t_minus)
    
    def forward_1(self,x):
        return x - self.square_2(x)/self.z

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=32
A=torch.diag(torch.flatten(1/(1+torch.tensor(range(matrix_size), dtype=torch.float64)).unsqueeze(0)+
                           1/(1+torch.tensor(range(matrix_size), dtype=torch.float64)).unsqueeze(1)))

matrix_size=matrix_size**2
import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float64)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

probs = 0.5*torch.ones((batch_size, 1)).cuda()
delta = torch.tensor(1e-6).cuda()

class truncated(nn.Module):
    def __init__(self, f, x):
        super(truncated, self).__init__()
        
        y = f(x)
        xx = torch.sum(x**2, dim = -1, keepdim = True)
        xAx = torch.sum(x*y, dim = -1, keepdim = True)
        xAAx = torch.sum(y**2, dim = -1, keepdim = True)
        mean = xAx/xx
        sigma = torch.sqrt(xAAx*xx - xAx**2)/xx
        self.mean = mean + (-1)**torch.bernoulli(probs) * sigma
        self.sigma = torch.sqrt(-2*sigma*torch.log(delta))
        self.lambda_min = self.mean - self.sigma
        self.lambda_max = self.mean + self.sigma
        self.f = f

    def forward(self, x):
        y = self.f(x) - self.mean * x
        y = self.f(y) - self.mean * y
        return x - y / torch.min(torch.cat((
            (self.lambda_min - self.mean)**2, 
            (self.lambda_max - self.mean)**2), 
                dim = -1), dim = -1, keepdim = True)[0]

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
stepses=[]
with torch.cuda.device(0):
    for i in range(1):
        x = torch.empty((batch_size, matrix_size), dtype=torch.float64).normal_(mean = 0, std = 1).cuda()
        steps = torch.zeros((batch_size, 1)).cuda()
        stepses.append(torch.zeros((batch_size, 1)).cuda())
        t[0] = x/torch.linalg.norm(x, dim = -1, keepdim = True)
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = modified(expr_1, f, batch_size)
        s = grad_ascend_lr(g, t[0], 1e-30, steps, 2000)
        t[0] = torch.where(cond, s[0], t[0])
        steps = s[-1]
        stepses[i] = torch.where(cond, s[-1], stepses[i])
        s = [1]
        cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
        g = truncated(f, t[0])
        while torch.any((steps < 2000) & cond):
            s = grad_ascend_lr(g, t[0], 1e-30, steps, 2000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            stepses[i] = torch.where(cond, s[-1], stepses[i])
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim = -1, keepdim = True)[0] < 0.9)
            g = truncated(f, t[0])
        result+= torch.sum(torch.where(torch.abs(h(t[0])) >= 0.9, 1, 0), dim = 0).cpu().numpy()
        print(i, list(np.where(result > 0)[0]), np.sum(result) / ((i + 1) * batch_size), end='\r')


t


  Q = complex_cbrt((e + torch.sqrt(torch.tensor(e ** 2 - 4 * a ** 3, dtype=torch.complex64))) / 2)


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 5.68 GiB of which 1.62 GiB is free. Process 6834 has 1.78 GiB memory in use. Including non-PyTorch memory, this process has 2.25 GiB memory in use. Of the allocated memory 602.42 MiB is allocated by PyTorch, and 1.53 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)