In [None]:
import torch

device=torch.device("cuda:0")

In [13]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians

    # Calculate the cubic root of the magnitude and the angle divided by 3
    root_magnitude = r ** (1/3)  # Magnitude of the cubic root
    root_angle = theta / 3  # Angle of the cubic root

    # Convert back to rectangular form (Cartesian coordinates)
    root_real = root_magnitude * torch.cos(root_angle)
    root_imag = root_magnitude * torch.sin(root_angle)

    return root_real + 1j * root_imag  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e=e/a
    d=d/a
    c=c/a
    b=b/a
    Delta_0 = c ** 2 - 3 * b * d + 12 * e
    Delta_1 = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e

    Q = complex_cbrt((Delta_1 + torch.sqrt(torch.tensor(Delta_1 ** 2 - 4 * Delta_0 ** 3, dtype=torch.complex64))) / 2)
    
    # Calculate p and q
    p = c - 0.375 * b ** 2
    q = (0.5*b) ** 3 - 0.5 * b * c + d    
    S = torch.sqrt(-2 / 3 * p + (Q + Delta_0 / Q) / 3) / 2
    S_1 = 0.5*torch.sqrt(-4 * S ** 2 - 2 * p + q / S)
    S_2 = 0.5*torch.sqrt(-4 * S ** 2 - 2 * p - q / S)
    b=-0.25*b

    lr_1 = torch.real(b - S + S_1)
    lr_2 = torch.real(b - S - S_1)
    lr_3 = torch.real(b + S + S_2)
    lr_4 = torch.real(b + S - S_2)
    
    return torch.cat([lr_1,lr_2,lr_3,lr_4],dim=-1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    p = c - b ** 2 / 3
    q = 2/27*b**3 - b*c/3 + d
    Q = (p/3)**3 + (q/2)**2
    S = torch.sqrt(Q)
    alpha = -q/2 + S
    beta = -q/2 - S
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta)+1j*sqrt_3*(alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta)-1j*sqrt_3*(alpha-beta))/2)
    return torch.cat([lr1,lr2,lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1,x2,torch.zeros_like(x1),torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.einsum('ij,ij->i', x, x).unsqueeze(-1)
    a_1 = torch.einsum('ij,ij->i', u, x).unsqueeze(-1)
    a_2 = torch.einsum('ij,ij->i', v, x).unsqueeze(-1)
    a_3 = torch.einsum('ij,ij->i', w, x).unsqueeze(-1)
    a_4 = torch.einsum('ij,ij->i', w, u).unsqueeze(-1)
    a_5 = torch.einsum('ij,ij->i', w, v).unsqueeze(-1)
    a_6 = torch.einsum('ij,ij->i', w, w).unsqueeze(-1)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    p_0 = a_0
    p_1 = a_1
    p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = p_0 * r_1 * r_2 - 2 * p_1 * r_0 * r_2 + p_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * p_0 * r_1 * q_2 - 3 * r_0 * p_1 * q_2
    d = 2 * p_0 * q_1 * q_2 + 2 * p_0 * r_1 * p_2 - p_0 * p_1 * r_2 - r_0 * p_1 * p_2
    e = 2 * p_0 * q_1 * p_2 - p_0 * p_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x=x+v
    return x/torch.linalg.norm(x, dim=-1).unsqueeze(-1)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A,x)
    x = update_vector(x, v)
    steps = steps_already+1
    cond = (e<threshold) & (steps<steps_max)
    while cond.any():
        x = torch.where(cond, update_vector(x, v), x)
        steps = torch.where(cond, steps+1, steps)
        e,v,i,l=optimal_lr(A,x)
        cond=(e<threshold)&(steps<steps_max)
    f=A(x)
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x)/torch.linalg.norm(f, dim=-1).unsqueeze(-1),steps

import torch.nn as nn

def add(x,y):
    return x+y

class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=(torch.rand(batch_size,1).cuda()*torch.max(torch.cat([self.f**2,(1-self.f)**2],dim=1),dim=1)[0].unsqueeze(-1)).cuda()
        self.z=torch.max(torch.cat([self.t**2,(self.t-self.f**2)**2,(self.t-(1-self.f)**2)**2], dim=1), dim=1)[0].unsqueeze(-1)
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
    def forward_1(self,x):
            y=self.A(x)-self.f*x
            y=self.A(y)-self.f*y
            y=self.t*x-y
            z=self.A(y)-self.f*y
            z=self.A(z)-self.f*z
            y=self.t*y-z
            return self.z*x-y

        
    def forward(self,x):
        
        return self.forward_1(x) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=10000
matrix_size=1000
A=torch.diag(1/(torch.tensor(range(matrix_size), dtype=torch.float32)+1))

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float32)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
with torch.cuda.device(0):
    for i in range(10000):
        x = torch.empty((batch_size,matrix_size)).normal_(mean=0,std=1).cuda()
        steps = torch.zeros((batch_size,1)).cuda()
        t[0] = x/torch.linalg.norm(x, dim=-1, keepdim=True)
        cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        while torch.any((steps<2000) & cond):
            g = modified(expr_1,f,batch_size)
            s = grad_ascend_lr(g,t[0],0.9999,steps,2000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        result+= torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()
        print(i, list(np.where(result>0)[0]), end='\r')
        t = [1]



t


  Q = complex_cbrt((Delta_1 + torch.sqrt(torch.tensor(Delta_1 ** 2 - 4 * Delta_0 ** 3, dtype=torch.complex64))) / 2)


6 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 29, 35]

KeyboardInterrupt: 

In [14]:
result[np.where(result>0)[0]]

array([1.6372e+04, 8.2040e+03, 4.5570e+03, 3.1630e+03, 1.9740e+03,
       1.1010e+03, 5.7800e+02, 2.5800e+02, 1.4500e+02, 8.7000e+01,
       6.5000e+01, 3.2000e+01, 2.7000e+01, 1.9000e+01, 8.0000e+00,
       1.0000e+01, 4.0000e+00, 4.0000e+00, 2.0000e+00, 3.0000e+00,
       3.0000e+00, 3.0000e+00, 3.0000e+00, 3.0000e+00, 1.0000e+00,
       1.0000e+00, 1.0000e+00])

In [12]:
############################################################3
##############################################################
#############################################################

tensor(71181., device='cuda:0')


In [None]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians

    # Calculate the cubic root of the magnitude and the angle divided by 3
    root_magnitude = r ** (1/3)  # Magnitude of the cubic root
    root_angle = theta / 3  # Angle of the cubic root

    # Convert back to rectangular form (Cartesian coordinates)
    root_real = root_magnitude * torch.cos(root_angle)
    root_imag = root_magnitude * torch.sin(root_angle)

    return root_real + 1j * root_imag  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e=e/a
    d=d/a
    c=c/a
    b=b/a
    Delta_0 = c ** 2 - 3 * b * d + 12 * e
    Delta_1 = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e

    Q = complex_cbrt((Delta_1 + torch.sqrt(torch.tensor(Delta_1 ** 2 - 4 * Delta_0 ** 3, dtype=torch.complex64))) / 2)
    
    # Calculate p and q
    p = c - 0.375 * b ** 2
    q = (0.5*b) ** 3 - 0.5 * b * c + d    
    S = torch.sqrt(-2 / 3 * p + (Q + Delta_0 / Q) / 3) / 2
    S_1 = 0.5*torch.sqrt(-4 * S ** 2 - 2 * p + q / S)
    S_2 = 0.5*torch.sqrt(-4 * S ** 2 - 2 * p - q / S)
    b=-0.25*b

    lr_1 = torch.real(b - S + S_1)
    lr_2 = torch.real(b - S - S_1)
    lr_3 = torch.real(b + S + S_2)
    lr_4 = torch.real(b + S - S_2)
    
    return torch.cat([lr_1,lr_2,lr_3,lr_4],dim=-1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    p = c - b ** 2 / 3
    q = 2/27*b**3 - b*c/3 + d
    Q = (p/3)**3 + (q/2)**2
    S = torch.sqrt(Q)
    alpha = -q/2 + S
    beta = -q/2 - S
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta)+1j*sqrt_3*(alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta)-1j*sqrt_3*(alpha-beta))/2)
    return torch.cat([lr1,lr2,lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1,x2,torch.zeros_like(x1),torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.einsum('ij,ij->i', x, x).unsqueeze(-1)
    a_1 = torch.einsum('ij,ij->i', u, x).unsqueeze(-1)
    a_2 = torch.einsum('ij,ij->i', v, x).unsqueeze(-1)
    a_3 = torch.einsum('ij,ij->i', w, x).unsqueeze(-1)
    a_4 = torch.einsum('ij,ij->i', w, u).unsqueeze(-1)
    a_5 = torch.einsum('ij,ij->i', w, v).unsqueeze(-1)
    a_6 = torch.einsum('ij,ij->i', w, w).unsqueeze(-1)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    p_0 = a_0
    p_1 = a_1
    p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = p_0 * r_1 * r_2 - 2 * p_1 * r_0 * r_2 + p_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * p_0 * r_1 * q_2 - 3 * r_0 * p_1 * q_2
    d = 2 * p_0 * q_1 * q_2 + 2 * p_0 * r_1 * p_2 - p_0 * p_1 * r_2 - r_0 * p_1 * p_2
    e = 2 * p_0 * q_1 * p_2 - p_0 * p_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x=x+v
    return x/torch.linalg.norm(x, dim=-1).unsqueeze(-1)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A,x)
    x = update_vector(x, v)
    steps = steps_already+1
    cond = (e<threshold) & (steps<steps_max)
    while cond.any():
        x = torch.where(cond, update_vector(x, v), x)
        steps = torch.where(cond, steps+1, steps)
        e,v,i,l=optimal_lr(A,x)
        cond=(e<threshold)&(steps<steps_max)
    f=A(x)
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x)/torch.linalg.norm(f, dim=-1).unsqueeze(-1),steps

import torch.nn as nn

def add(x,y):
    return x+y

class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=(torch.rand(batch_size,1).cuda()*torch.max(torch.cat([self.f**2,(1-self.f)**2],dim=1),dim=1)[0].unsqueeze(-1)).cuda()
        self.z=torch.max(torch.cat([self.t**2,(self.t-self.f**2)**2,(self.t-(1-self.f)**2)**2], dim=1), dim=1)[0].unsqueeze(-1)
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
    def forward_1(self,x):
            y=self.A(x)-self.f*x
            y=self.A(y)-self.f*y
            y=self.t*x-y
            z=self.A(y)-self.f*y
            z=self.A(z)-self.f*z
            y=self.t*y-z
            return self.z*x-y

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=10000
matrix_size=1000
A=torch.diag(1/(torch.tensor(range(matrix_size), dtype=torch.float32)+1))

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float32)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
with torch.cuda.device(0):
    for i in range(10000):
        x = torch.empty((batch_size,matrix_size)).normal_(mean=0,std=1).cuda()
        steps = torch.zeros((batch_size,1)).cuda()
        t[0] = x/torch.linalg.norm(x, dim=-1, keepdim=True)
        cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        while torch.any((steps<2000) & cond):
            g = modified(expr_1,f,batch_size)
            s = grad_ascend_lr(g,t[0],0.9999,steps,2000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        result+= torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()
        print(i, list(np.where(result>0)[0]), end='\r')
        t = [1]



t


  Q = complex_cbrt((Delta_1 + torch.sqrt(torch.tensor(Delta_1 ** 2 - 4 * Delta_0 ** 3, dtype=torch.complex64))) / 2)


0 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

In [None]:
print(torch.min(steps), torch.max(steps), torch.sum(steps))

In [2]:
print(np.where(result>0))

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 16, 17,
       21, 22]),)


In [None]:
def grad_ascend_lr_transform(C,x,j,steps_already):
    f=jra.uniform(key = jra.PRNGKey(2*j))
    B=C-f*I
    B=B@B
    t=jra.uniform(key = jra.PRNGKey(2*j+1))*f**2
    B=t*I-B
    B=B@B
    z=max(t**2, (t-f**2)**2, (t-(1-f)**2)**2)
    B=z*I-B
    
    return grad_ascend_lr(B,x, steps_already)


In [56]:
x=torch.empty((20,10)).normal_(mean=0,std=1)
print(x.shape)

torch.Size([20, 10])


In [None]:
###################################################
#######################################################################################################
###########################################################################################################
###############################################################################

In [2]:
import torch
import sympy
from sympy import Symbol, Mul, Pow, Add

def print_gpu_memory():
    """Prints the current allocated, reserved, and free memory on the GPU."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        allocated_memory = torch.cuda.memory_allocated(device) / (1024 ** 2)  # Convert to MB
        reserved_memory = torch.cuda.memory_reserved(device) / (1024 ** 2)  # Convert to MB
        free_memory = torch.cuda.get_device_properties(device).total_memory - torch.cuda.memory_allocated(device)  # Free memory in reserved area
        free_memory_mb = free_memory / (1024 ** 2)  # Convert to MB

        print(f"Allocated memory: {allocated_memory:.2f} MB")
        print(f"Reserved memory: {reserved_memory:.2f} MB")
        print(f"Free memory: {free_memory_mb:.2f} MB")
    else:
        print("CUDA is not available. No GPU detected.")

t=sympy.Symbol('t')
sqrt_3=torch.sqrt(torch.tensor([3.])).cuda()


def complex_cbrt(z):
    """Calculate the cubic root of a complex number."""
    # Calculate the magnitude and angle of the complex number
    r = z.abs()  # Magnitude
    theta = torch.angle(z)  # Angle in radians

    # Calculate the cubic root of the magnitude and the angle divided by 3
    root_magnitude = r ** (1/3)  # Magnitude of the cubic root
    root_angle = theta / 3  # Angle of the cubic root

    # Convert back to rectangular form (Cartesian coordinates)
    root_real = root_magnitude * torch.cos(root_angle)
    root_imag = root_magnitude * torch.sin(root_angle)

    return root_real + 1j * root_imag  # Return as a complex number

def quartic_solver(a,b,c,d,e):
    e=e/a
    d=d/a
    c=c/a
    b=b/a
    Delta_0 = c ** 2 - 3 * b * d + 12 * e
    Delta_1 = 2 * c ** 3 - 9 * b * c * d + 27 * b ** 2 * e + 27 * d ** 2 - 72 * c * e

    Q = complex_cbrt((Delta_1 + torch.sqrt(torch.tensor(Delta_1 ** 2 - 4 * Delta_0 ** 3, dtype=torch.complex64))) / 2)
    
    # Calculate p and q
    p = c - 0.375 * b ** 2
    q = (0.5*b) ** 3 - 0.5 * b * c + d    
    S = torch.sqrt(-2 / 3 * p + (Q + Delta_0 / Q) / 3) / 2
    S_1 = 0.5*torch.sqrt(-4 * S ** 2 - 2 * p + q / S)
    S_2 = 0.5*torch.sqrt(-4 * S ** 2 - 2 * p - q / S)
    b=-0.25*b

    lr_1 = torch.real(b - S + S_1)
    lr_2 = torch.real(b - S - S_1)
    lr_3 = torch.real(b + S + S_2)
    lr_4 = torch.real(b + S - S_2)
    
    return torch.cat([lr_1,lr_2,lr_3,lr_4],dim=-1)

def cubic_solver(a,b,c,d):
    d=d/a
    c=c/a
    b=b/a
    p = c - b ** 2 / 3
    q = 2/27*b**3 - b*c/3 + d
    Q = (p/3)**3 + (q/2)**2
    S = torch.sqrt(Q)
    alpha = -q/2 + S
    beta = -q/2 - S
    alpha = complex_cbrt(alpha)
    beta = complex_cbrt(beta)
    
    lr1 = torch.real(alpha+beta)
    lr2 = torch.real((-(alpha+beta)+1j*sqrt_3*(alpha-beta))/2)
    lr3 = torch.real((-(alpha+beta)-1j*sqrt_3*(alpha-beta))/2)
    return torch.cat([lr1,lr2,lr3, torch.zeros_like(lr1)],dim=-1)


def quadratic_solver(c,d,e):
    D = torch.sqrt(d ** 2 - 4 * e * c)
    x1 = torch.real((-d + D)/(2*c))
    x2 = torch.real((-d - D)/(2*c))
    return torch.cat([x1,x2,torch.zeros_like(x1),torch.zeros_like(x1)], dim=-1)

def polynomial_solver(a,b,c,d,e):
    return torch.where(a!=0, 
                       quartic_solver(a,b,c,d,e), torch.where(
                        b!=0,
                        cubic_solver(b,c,d,e), torch.where(
                        c!=0,
                        quadratic_solver(c,d,e), torch.where(
                        d!=0, 
                        (e/d).repeat((len(a.shape)-1)*(1,)+(4,)), 
                        e.repeat((len(a.shape)-1)*(1,)+(4,))))))


def optimal_lr(A, x):
    """Calculate optimal learning rate based on matrix A and vector x."""
    # Compute intermediate vectors
    u = A(x)
    v = A(u)
    w = A(v)
    
    # Compute inner products
    a_0 = torch.einsum('ij,ij->i', x, x).unsqueeze(-1)
    a_1 = torch.einsum('ij,ij->i', u, x).unsqueeze(-1)
    a_2 = torch.einsum('ij,ij->i', v, x).unsqueeze(-1)
    a_3 = torch.einsum('ij,ij->i', w, x).unsqueeze(-1)
    a_4 = torch.einsum('ij,ij->i', w, u).unsqueeze(-1)
    a_5 = torch.einsum('ij,ij->i', w, v).unsqueeze(-1)
    a_6 = torch.einsum('ij,ij->i', w, w).unsqueeze(-1)

    # Calculate r_0, r_1, r_2
    r_0 = 4 * a_2 / (a_1 * a_1) - 2 * a_1 / (a_1 * a_0) - 2 * a_3 / (a_1 * a_2) - 2 * a_3 / (a_1 * a_2) + a_4 / (a_2 * a_2) + a_2 / (a_0 * a_2) - 2 * a_1 / (a_1 * a_0) + a_0 / (a_0 * a_0) + a_2 / (a_0 * a_2)
    r_1 = 4 * a_3 / (a_1 * a_1) - 2 * a_2 / (a_1 * a_0) - 2 * a_4 / (a_1 * a_2) - 2 * a_4 / (a_1 * a_2) + a_5 / (a_2 * a_2) + a_3 / (a_0 * a_2) - 2 * a_2 / (a_1 * a_0) + a_1 / (a_0 * a_0) + a_3 / (a_0 * a_2)
    r_2 = 4 * a_4 / (a_1 * a_1) - 2 * a_3 / (a_1 * a_0) - 2 * a_5 / (a_1 * a_2) - 2 * a_5 / (a_1 * a_2) + a_6 / (a_2 * a_2) + a_4 / (a_0 * a_2) - 2 * a_3 / (a_1 * a_0) + a_2 / (a_0 * a_0) + a_4 / (a_0 * a_2)

    # Calculate q and p
    q_1 = 2 * a_2 / a_1 - a_1 / a_0 - a_3 / a_2
    q_2 = 2 * a_3 / a_1 - a_2 / a_0 - a_4 / a_2
    p_0 = a_0
    p_1 = a_1
    p_2 = a_2

    a = r_0 * r_1 * q_2 - 2 * r_0 * q_1 * r_2
    b = p_0 * r_1 * r_2 - 2 * p_1 * r_0 * r_2 + p_2 * r_0 * r_1 - 2 * q_1 * q_2 * r_0
    c = 3 * p_0 * r_1 * q_2 - 3 * r_0 * p_1 * q_2
    d = 2 * p_0 * q_1 * q_2 + 2 * p_0 * r_1 * p_2 - p_0 * p_1 * r_2 - r_0 * p_1 * p_2
    e = 2 * p_0 * q_1 * p_2 - p_0 * p_1 * q_2
    lr = polynomial_solver(a,b,c,d,e)
    g = a_1 + 2 * lr * q_1 + lr ** 2 * r_1
    f = a_0 + lr ** 2 * r_0
    h = a_2 + 2 * lr * q_2 + lr ** 2 * r_2
    eigenness = g**2/(f*h)
    n = torch.argmax(eigenness, dim=-1, keepdim=True)
    lr = torch.gather(lr, -1, n)
    return torch.gather(eigenness, -1, n), (-lr / a_0) * x + (2 * lr / a_1) * u + (-lr / a_2) * v, n, lr


def update_vector(x, v):
    x=x+v
    return x/torch.linalg.norm(x, dim=-1).unsqueeze(-1)

def grad_ascend_lr(A,x,threshold,steps_already,steps_max):
    e,v,i,l = optimal_lr(A,x)
    x = update_vector(x, v)
    steps = steps_already+1
    cond = (e<threshold) & (steps<steps_max)
    while cond.any():
        x = torch.where(cond, update_vector(x, v), x)
        steps = torch.where(cond, steps+1, steps)
        e,v,i,l=optimal_lr(A,x)
        cond=(e<threshold)&(steps<steps_max)
    f=A(x)
    return x, torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1).unsqueeze(-1)**2, (f-(torch.einsum('ij,ij->i',f,x)/torch.linalg.norm(x,dim=-1)).unsqueeze(-1)**2*x)/torch.linalg.norm(f, dim=-1).unsqueeze(-1),steps

import torch.nn as nn

def add(x,y):
    return x+y

class modified(nn.Module):
    def __init__(self,expr,A,batch_size):
        super(modified, self).__init__()
        self.parts=[]
        for arg in expr.args:
            self.parts.append(modified(arg,A,batch_size))
        if expr.is_Number:
            self.param = nn.Parameter(torch.random.uniform((batch_size,1))*float(expr))
        self.f=torch.rand(batch_size,1).cuda()
        self.t=(torch.rand(batch_size,1).cuda()*torch.max(torch.cat([self.f**2,(1-self.f)**2],dim=1),dim=1)[0].unsqueeze(-1)).cuda()
        self.z=torch.max(torch.cat([self.t**2,(self.t-self.f**2)**2,(self.t-(1-self.f)**2)**2], dim=1), dim=1)[0].unsqueeze(-1)
        if expr == sympy.Symbol('t'):
            self.A=A
        self.expr=expr
    def forward_1(self,x):
            y=self.A(x)-self.f*x
            y=self.A(y)-self.f*y
            y=self.t*x-y
            z=self.A(y)-self.f*y
            z=self.A(z)-self.f*z
            y=self.t*y-z
            return self.z*x-y

        
    def forward(self,x):
        y = self.forward_1(x)
        y = self.forward_1(y)
        y = self.forward_1(y)
        return self.forward_1(y) 
        if self.expr == sympy.Symbol('t'):
            return self.A(x)

        if self.expr.is_Number:
            return self.param*x

        if self.expr.is_Add:
            # Initialize res as float zeros to avoid dtype conflicts
            res = torch.zeros_like(x)
            for arg in self.parts:
                res += arg(x)
            return res

        if self.expr.is_Mul:
            # product: to mimic polynomial multiplication
            for arg in self.parts:
                x = arg(x)
            return x

        if self.expr.is_Pow:
            for _ in range(self.parts[1]):
                x = self.parts[0](x)
            return x
 

        
import numpy as np



batch_size=16384
matrix_size=1024
A=torch.diag(torch.tensor(range(matrix_size), dtype=torch.float32)/matrix_size)

import scipy

C = torch.tensor(scipy.stats.ortho_group.rvs(dim=matrix_size), dtype=torch.float32)
A = torch.linalg.inv(C) @ A @ C

f=nn.Linear(matrix_size,matrix_size, bias=False).cuda()
f.weight=nn.Parameter(A.cuda(), requires_grad=False)
#f.bias=nn.Parameter(torch.zeros_like(f.bias))

h = nn.Linear(matrix_size, matrix_size, bias=False).cuda()
h.weight=nn.Parameter(C.cuda(), requires_grad=False)

result=np.zeros((matrix_size,))
expr_1=t
print(expr_1)
t=[1]
with torch.cuda.device(0):
    for i in range(10000):
        x = torch.empty((batch_size,matrix_size)).normal_(mean=0,std=1).cuda()
        steps = torch.zeros((batch_size,1)).cuda()
        t[0] = x/torch.linalg.norm(x, dim=-1, keepdim=True)
        cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        while torch.any((steps<2000) & cond):
            g = modified(expr_1,f,batch_size)
            s = grad_ascend_lr(g,t[0],0.9999,steps,2000)
            t[0] = torch.where(cond, s[0], t[0])
            steps = s[-1]
            s = [1]
            cond = (torch.max(torch.abs(h(t[0])), dim=-1, keepdim=True)[0]<0.9)
        result+= torch.sum(torch.where(torch.abs(h(t[0]))>=0.9,1,0),dim=0).cpu().numpy()
        print(i, list(np.where(result>0)[0]), end='\r')
        t = [1]



t


  Q = complex_cbrt((Delta_1 + torch.sqrt(torch.tensor(Delta_1 ** 2 - 4 * Delta_0 ** 3, dtype=torch.complex64))) / 2)


2 [0, 1, 2, 3, 12, 16, 18, 31, 32, 39, 60, 77, 98, 99, 113, 114, 115, 118, 120, 124, 155, 160, 163, 235, 237, 701, 787, 813, 839, 858, 873, 890, 896, 907, 935, 945, 948, 970, 987, 997, 1004, 1018, 1019, 1020, 1021, 1022, 1023]

KeyboardInterrupt: 

In [5]:
print(result[np.where(result>0)])

[47. 13.  5.  2.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.  1.  1.  1.  1.  1.  3.  5.  9. 48.]


In [6]:
print(torch.sum(steps))

tensor(1185432., device='cuda:0')
