In [1]:
import numpy as np
import pandas as pd
import torch
from scipy.linalg import kron
from scipy.sparse.linalg import svds

## 1. Check decomposition of torch row-wise permuted matrix:

In [6]:

def reshape_and_permute(A, m, n):
    # Reshape to (m, n, m, n) in row-major order (default in PyTorch)
    A_reshaped = A.reshape(m, n, m, n)
    
    # Permute the axes (0, 2, 1, 3) to interleave dimensions
    A_permuted = A_reshaped.permute(0, 2, 1, 3)
    
    # Reshape to (m*m, n*n) in column-major order
    RA = A_permuted.reshape(m * m, n * n)  # Transpose for column-major layout
    
    return RA

In [3]:
m = 5
n = 10

B = torch.rand(m, m)
B = B + B.T
C = torch.rand(n, n)
C = C + C.T

A = torch.kron(B, C)

noise1=torch.randn(A.size())*0.01
A1 = reshape_and_permute(A, m, n)


def getitem_fmatrix_permuted(matrix, p, q):
    return matrix[p][q]

from scipy.sparse.linalg import svds, LinearOperator

#@jit
def matvec_item_custom(x): #m**2\times n**2, m**2 -> n**2 
    #print ("matvec_item_custom")
    res =np.zeros((m**2))
    for ind in range(m**2):
        sum_ = 0.0
        for jnd, elem in enumerate(x):
            sum_+= getitem_fmatrix_permuted(A1, ind, jnd)*elem
        res[ind] = sum_
    return res

#@jit
def rmatvec_item_custom(x): # A.T*y
    res = np.zeros((n**2))
    for ind in range(n**2):
        #if (ind%10000) == 0:
            #print (ind, "from", n**2)
        sum_ = 0.0
        for jnd, elem in enumerate(x):
            #print ("r", jnd, ind)
            sum_+= getitem_fmatrix_permuted(A1, jnd, ind)*elem
        res[ind] = sum_
    return res
    

print ("start to create operator", flush = True)


linop_v = LinearOperator(
                shape=(m**2,n**2),
                matvec=matvec_item_custom,
                rmatvec=rmatvec_item_custom
                )


print ("operator created", flush = True)

print ("svds...", flush = True)

v, s, u = svds(linop_v, k = 3, return_singular_vectors=True)


start to create operator
operator created
svds...


In [4]:
s 

array([3.67000299e-07, 6.04088830e-07, 6.48285812e+01])

In [5]:
C1 = s[-1] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
print (np.linalg.norm(A-np.kron(B1, C1))/np.linalg.norm(A))


1.2854315783118753e-07


## 2. Check decomposition of torch column-wise permuted matrix with parrallel:

In [15]:
def reshape_and_permute(A, m, n):
    # Reshape to (m, n, m, n) in row-major order (default in PyTorch)
    A_reshaped = A.reshape(m, n, m, n)
    
    # Permute the axes (0, 2, 1, 3) to interleave dimensions
    A_permuted = A_reshaped.permute(0, 2, 1, 3)
    
    # Reshape to (m*m, n*n) in column-major order
    RA = A_permuted.reshape(m * m, n * n)  # Transpose for column-major layout
    
    return RA

In [9]:
import torch
import cupy as cp
import numpy as np
from cupyx.scipy.sparse.linalg import svds, LinearOperator
import numba
from numba import jit
from numba import cuda

#@cuda.jit

# Load the tensor and convert to CuPy array
#a = cp.asarray(torch.load("last_one.pt")) old
import numpy as np
from numba import cuda
from scipy.sparse.linalg import svds, LinearOperator
import torch

m = 500
n = 100

B = torch.rand(m, m)
B = B + B.T
C = torch.rand(n, n)
C = C + C.T

A = torch.kron(B, C)

noise1=torch.randn(A.size())*0.01
A1 = reshape_and_permute(A, m, n)




torch.manual_seed(42)

#import numpy as np
d_A1 = cuda.to_device(A1)


In [10]:
@cuda.jit
def matvec_kernel(d_A1, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    #pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_A1[np.int32(idx)][jnd]
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel(d_A1, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        #pd, pn = divmod(np.int32(idx), m)
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            sum_ += x[jnd]*d_A1[jnd][np.int32(idx)]
        res[idx] = sum_

def matvec_item_custom(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 256
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block

    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel[blocks_per_grid, threads_per_block](d_A1, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    x = x.ravel()
    threads_per_block = 256
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block

    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel[blocks_per_grid, threads_per_block](d_A1, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res


In [18]:
linop_v = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom,
    rmatvec=rmatvec_item_custom
)
print("Operator created", flush=True)

# Compute SVD
print("Performing SVD...", flush=True)
v, s, u = svds(linop_v, k = 10, return_singular_vectors=True)
s = np.sort(s)[::-1]

matvec
Operator created
Performing SVD...
matvec




matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec


In [19]:
s

array([5.8762172e+04, 4.5752557e-05, 2.8291899e-05, 2.7793820e-05,
       2.7698212e-05, 2.7515733e-05, 2.6975136e-05, 2.6144406e-05,
       2.4383276e-05, 1.7502996e-05], dtype=float32)

In [20]:
C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
np.linalg.norm(A-np.kron(B1, C1))/np.linalg.norm(A)


2.8999048e-07

## 3. Check the correspondance between vector and matrix A1

### 3.2 Vector of ints

where

A = torch.outer(grad_vector, grad_vector)

A = permute_and_reshape(A1)


In [41]:
import numpy as np
import torch
grad_matrix = torch.tensor(np.array([[1, 2, 3, 4],[5, 6, 7, 8], [9, 10, 11, 12],[10, 20, 30, 40],[50, 60, 70, 80]]))
grad_vector = grad_matrix.reshape(grad_matrix.shape[0]*grad_matrix.shape[1])
#grad_vector = grad_matrix.T.reshape(-1)


In [42]:
m,n = grad_matrix.shape

In [36]:
fisher_matrix = torch.outer(grad_vector,grad_vector)

grad_vector = grad_matrix.T.reshape(-1)

# Same as in
    # def reshape_and_permute(A, m, n):
    # Reshape to (m, n, m, n) in row-major order (default in PyTorch)
    # A_reshaped = A.reshape(m, n, m, n)
    
    # Permute the axes (0, 2, 1, 3) to interleave dimensions
    # A_permuted = A_reshaped.permute(0, 2, 1, 3)
    
    # Reshape to (m*m, n*n) in column-major order
    #RA = A_permuted.reshape(m * m, n * n)  # Transpose for column-major layout
    
reshaped1 = fisher_matrix.reshape(m, n, m, n) # i, j -> k = j //m, l = j % m <=> j = k*n + l , i = i (i - number of row, j - number of column) 
reshaped2 = reshaped1.permute(0, 2, 1, 3) # i j k l -> i k j l
#i, k, l => k, i, l <=> given: i, k, l,  reshaped1: k, i, l, fisher_matrix: k, i*n + l
reshaped3 = reshaped2.reshape(m * m, n * n)


In [43]:

for p in range(reshaped3.shape[0]):
    for q in range(reshaped3.shape[1]):
            #print (i, j, "__", i//n, i%n, j//n, j%n)
            assert (reshaped3[p][q] == fisher_matrix[(p//m)*n +q//n, (p%m)*n + q%n]), print (p, q, "__", p//m, p%m, q//n, q%n)

In [None]:
for p in range(reshaped3.T.shape[0]):
    for q in range(reshaped3.T.shape[1]):
            #print (i, j, "__", i//n, i%n, j//n, j%n)
            assert reshaped3.T[p][q] == grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n], print (p, q, "__", q//m, q%m, p//n, p%n, reshaped3.T[p][q], grad_vector[(q // m)*n +p // n], grad_vector[(q % m)*n + p % n], "X", grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n])

reshaped3[p][q] == fisher_matrix[(p // m)*n +q // n, (p % m)*n + q % n] ==grad_vector[(p // m)*n +q // n] $\times$ 
grad_vector[(p % m)*n + q % n]

In [None]:
reshaped3[q][p] == fisher_matrix[(q // m)*n +p // n, (q % m)*n + p % n] ==grad_vector[(q // m)*n +p // n] $\times$ 
grad_vector[(q % m)*n + p % n]

### 3.2 ...and random vector

In [2]:
m,n = (5,4)

In [5]:
m,n = (5,4)
import numpy as np
import torch
grad_matrix = torch.randn(m,n)

#grad_vector = grad_matrix.reshape(grad_matrix.shape[0]*grad_matrix.shape[1])
grad_vector = grad_matrix.T.reshape(-1)
fisher_matrix = torch.outer(grad_vector,grad_vector)
permuted_fisher_matrix = reshape_and_permute(fisher_matrix, m, n)

In [8]:
for p in range(permuted_fisher_matrix.shape[0]):
    for q in range(permuted_fisher_matrix.shape[1]):
            assert (permuted_fisher_matrix[p][q] == fisher_matrix[(p//m)*n +q//n, (p%m)*n + q%n]), print (p, q, "__", p//m, p%m, q//n, q%n)
            assert (permuted_fisher_matrix[p][q] == grad_vector[(p//m)*n +q//n]*grad_vector[(p%m)*n + q%n]), print (p, q, "__", p//m, p%m, q//n, q%n)

In [None]:
for p in range(permuted_fisher_matrix.T.shape[0]):
    for q in range(permuted_fisher_matrix.T.shape[1]):
            assert permuted_fisher_matrix.T[p][q] == grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n], print (p, q, "__", q//m, q%m, p//n, p%n, permuted_fisher_matrix.T[p][q], grad_vector[(q // m)*n +p // n], grad_vector[(q % m)*n + p % n], "X", grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n])

In [None]:
p = d_A1[jnd] q = [np.int32(idx)]  (ind // m)*n + idx//n  (ind%m)*n + idx%n  rmatvec
p = d_A1[np.int32(idx)] q = [jnd]  (idx // m)*n + ind//n   (idx%m)*n + ind%n   matvec

In [None]:
grad_vector[(p//m)*n +q//n]*grad_vector[(p%m)*n + q%n])

## 4. Check the parrallel computing in case of full matrix A1 and grad_vector

In [1]:
import torch
import cupy as cp
import numpy as np
from scipy.sparse.linalg import svds, LinearOperator
import numba
from numba import jit
from numba import cuda

In [6]:
def reshape_and_permute(A, m, n):
    # Reshape to (m, n, m, n) in row-major order (default in PyTorch)
    A_reshaped = A.reshape(m, n, m, n)
    
    # Permute the axes (0, 2, 1, 3) to interleave dimensions
    A_permuted = A_reshaped.permute(0, 2, 1, 3)
    
    # Reshape to (m*m, n*n) in column-major order
    RA = A_permuted.reshape(m * m, n * n)  # Transpose for column-major layout
    
    return RA

In [26]:
import numpy as np
import torch
torch.manual_seed(42)


m = 50
n = 150

B = torch.rand(m, n)
#B = B + B.T


grad_matrix = B


#m,n = grad_matrix.shape

#grad_vector = torch.rand(m*n)#grad_matrix.reshape(grad_matrix.shape[0]*grad_matrix.shape[1])
grad_vector = grad_matrix.T.reshape(-1)
#grad_vector = grad_matrix.reshape(grad_matrix.shape[0]*grad_matrix.shape[1])
fisher_matrix = torch.outer(grad_vector,grad_vector)
permuted_fisher_matrix = reshape_and_permute(fisher_matrix, m, n)



In [27]:

d_A1 = cuda.to_device(permuted_fisher_matrix)

In [28]:
d_grad_vector = cuda.to_device(grad_vector)



In [29]:
d_A1 = cuda.to_device(permuted_fisher_matrix)
@cuda.jit
def matvec_kernel(d_A1, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    #pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_A1[np.int32(idx)][jnd]
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel(d_A1, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        #pd, pn = divmod(np.int32(idx), m)
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            sum_ += x[jnd]*d_A1[jnd][np.int32(idx)]
        res[idx] = sum_

def matvec_item_custom(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 10
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block

    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel[blocks_per_grid, threads_per_block](d_A1, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    x = x.ravel()
    threads_per_block = 10
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block

    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel[blocks_per_grid, threads_per_block](d_A1, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res


In [30]:
@cuda.jit
def matvec_kernel_v(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_grad_vector[(pd)*n +qd]*d_grad_vector[(pn)*n + qn] 
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel_v(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    idx_d, idx_m = divmod(np.int32(idx), n)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            jd, jn = divmod(jnd, m)
            sum_ += x[jnd]*d_grad_vector[jd*n +idx_d]*d_grad_vector[jn*n + idx_m]
            #sum_ += x[jnd]*d_grad_vector[(jnd//n)*n +pd]*d_grad_vector[(jnd%n)*n + pn] 
        res[idx] = sum_

def matvec_item_custom_v(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 10
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel_v[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom_v(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    x = x.ravel()
    threads_per_block = 10
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel_v[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

In [31]:
@cuda.jit
def matvec_kernel_v1(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_grad_vector[(pd)*n +qd]*d_grad_vector[(pn)*n + qn] 
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel_v1(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        pd, pn = divmod(np.int32(idx), m)
        for jnd in range(x.size):
            sum_ += x[jnd]*d_grad_vector[(jnd//m)*n +pd]*d_grad_vector[(jnd%m)*n + pn] 
        res[idx] = sum_

def matvec_item_custom_v1(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 10
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel_v1[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom_v1(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    print ("rmatvec")
    x = x.ravel()
    threads_per_block = 10
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel_v1[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

In [32]:
linop_m = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom,
    rmatvec=rmatvec_item_custom
)
print("Operator created", flush=True)



matvec
Operator created


In [33]:
linop_v = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom_v,
    rmatvec=rmatvec_item_custom_v
)
print("Operator created", flush=True)



matvec
Operator created


In [34]:
linop_v1 = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom_v1,
    rmatvec=rmatvec_item_custom_v1
)
print("Op")

matvec
Op


In [35]:
n, m

(150, 50)

In [36]:
left = torch.rand(n*n)
right = torch.rand(m*m)


In [37]:
right.shape

torch.Size([2500])

In [38]:
resr_1 = linop_v.rmatvec(right)
resr_2 = linop_m.rmatvec(right)
resr_3 = linop_v1.rmatvec(right)

rmatvec


In [39]:
resr_2

array([257.7887 , 266.3341 , 249.62654, ..., 286.33588, 342.83432,
       351.9065 ], dtype=float32)

In [40]:
np.allclose(resr_1, resr_2)

True

In [41]:
np.allclose(resr_3, resr_2)

False

In [42]:
resl_1 = linop_v1.matvec(left)
resl_2 = linop_m.matvec(left)

matvec
matvec


In [43]:
np.allclose(resl_1, resl_2)

True

In [44]:
# Compute SVD
print("Performing SVD on linop_v...", flush=True)
v, s, u = svds(linop_v1, k = 10, return_singular_vectors=True)
s = np.sort(s)[::-1]
print (s)

C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
np.linalg.norm(fisher_matrix-np.kron(B1, C1))/np.linalg.norm(fisher_matrix)


Performing SVD on linop_v...
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
mat

0.6638949

In [45]:
# Compute SVD
print("Performing SVD on linop_m...", flush=True)
v, s, u = svds(linop_m, k = 10, return_singular_vectors=True)
s = np.sort(s)[::-1]
print (s)

Performing SVD on linop_m...
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
[1896.4788   231.66359  231.66359  227.55841  227.55836  225.07861
  225.07849  215.91583  215.91582  207.23445]


### C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
np.linalg.norm(fisher_matrix-np.kron(B1, C1))/np.linalg.norm(fisher_matrix)


In [46]:
print("Performing SVD on linop_m...", flush=True)
v, s, u = svds(linop_v, k = 10, return_singular_vectors=True)
s = np.sort(s)[::-1]
print (s)

Performing SVD on linop_m...
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
[1896.4789   231.66362  231.6636   227.55838  227.55836  225.07848
  225.07845  215.9159   215.91579  207.23444]


In [47]:
C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
np.linalg.norm(fisher_matrix-np.kron(B1, C1))/np.linalg.norm(fisher_matrix)


0.6521983

## 5. Decompose gradient tensor

In [1]:
import torch
import cupy as cp
import numpy as np
from scipy.sparse.linalg import svds, LinearOperator
import numba
from numba import jit
from numba import cuda

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f87fb96a5f0>

In [3]:
import numpy as np

#d_res = cuda.device_array((n**2), dtype=np.float32)

a = torch.load("last_one.pt")
m = a.shape[0]
n = a.shape[1]

#grad_vector = a.T.reshape(-1)
grad_vector = a.reshape(a.shape[0]*a.shape[1])

d_grad_vector = cuda.to_device(grad_vector)



  a = torch.load("last_one.pt")


In [4]:
@cuda.jit
def matvec_kernel(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_grad_vector[(pd)*n +qd]*d_grad_vector[(pn)*n + qn] 
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        pd, pn = divmod(np.int32(idx), n)
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            sum_ += x[jnd]*d_grad_vector[(jnd//m)*n +pd]*d_grad_vector[(jnd%m)*n + pn]  #grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n]
        res[idx] = sum_

def matvec_item_custom(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 1024
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    print ("rmatvec")
    x = x.ravel()
    threads_per_block = 1024
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

In [None]:
# Create the operator
#print("Start to create operator", flush=True)
linop_v = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom,
    rmatvec=rmatvec_item_custom
)
print("Operator created", flush=True)

#right = torch.rand(m*m)
#resr_2 = linop_v.rmatvec(right)

print("Performing SVD...", flush=True)
v, s, u = svds(linop_v, k = 3, return_singular_vectors=True)
s = np.sort(s)[::-1]

matvec
Operator created
Performing SVD...
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec


In [8]:
s

array([0.02324359, 0.01793001, 0.01793   ], dtype=float32)

In [10]:
s

array([0.05736788, 0.01356099, 0.01356099], dtype=float32)

In [6]:
s

array([0.00679347, 0.00094499, 0.00055086], dtype=float32)

In [6]:
s

array([0.00134553, 0.00083746, 0.00066922], dtype=float32)

In [10]:
C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)


In [11]:
C1.shape

(768, 768)

In [12]:
C1 = (C1 + C1.T)/2

In [13]:
import numpy as np

def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

is_pos_def(C1)

False

In [14]:
is_pos_def(B1)

False

In [21]:
C1 = (s[0] * u[-1, :]).reshape(n, n).T
B1 = v[:, -1].reshape(m, m).T

C1 = (C1 + C1.T)/2

is_pos_def(C1)

False

In [20]:
is_pos_def(B1)

False

## 6. Check the order of singular values

In [3]:
import torch
import cupy as cp
import numpy as np
from scipy.sparse.linalg import svds, LinearOperator
import numba
from numba import jit
from numba import cuda

In [4]:
import numpy as np

#d_res = cuda.device_array((n**2), dtype=np.float32)

a = torch.load("linlayer.pt")/800
m = a.shape[0]
n = a.shape[1]

#grad_vector = a.T.reshape(-1)
grad_vector = a.reshape(a.shape[0]*a.shape[1])

d_grad_vector = cuda.to_device(grad_vector)


  a = torch.load("linlayer.pt")/800


In [20]:
@cuda.jit
def matvec_kernel(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_grad_vector[(pd)*n +qd]*d_grad_vector[(pn)*n + qn] 
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel(d_grad_vector, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        pd, pn = divmod(np.int32(idx), n)
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            sum_ += x[jnd]*d_grad_vector[(jnd//m)*n +pd]*d_grad_vector[(jnd%m)*n + pn]  #grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n]
        res[idx] = sum_

def matvec_item_custom(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 1
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    print ("rmatvec")
    x = x.ravel()
    threads_per_block = 1
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel[blocks_per_grid, threads_per_block](d_grad_vector, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

In [21]:
linop_v = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom,
    rmatvec=rmatvec_item_custom
)
print("Operator created", flush=True)

#right = torch.rand(m*m)
#resr_2 = linop_v.rmatvec(right)

print("Performing SVD...", flush=True)
v, s, u = svds(linop_v, k = 3, return_singular_vectors=True)
s = np.sort(s)[::-1]

matvec
Operator created
Performing SVD...
rmatvec
matvec




rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
matvec
rmatvec
rmatvec
rmatvec


In [22]:
s

array([0.00514201, 0.00307254, 0.00307254], dtype=float32)

In [7]:
fisher_matrix = torch.outer(grad_vector,grad_vector)
permuted_fisher_matrix = reshape_and_permute(fisher_matrix, m, n)



In [38]:
C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
np.linalg.norm(fisher_matrix-np.kron(B1, C1))/np.linalg.norm(fisher_matrix)


0.76467013

In [8]:
d_A1 = cuda.to_device(permuted_fisher_matrix)
@cuda.jit
def matvec_kernel(d_A1, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    #pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*d_A1[np.int32(idx)][jnd]
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel(d_A1, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        #pd, pn = divmod(np.int32(idx), m)
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            sum_ += x[jnd]*d_A1[jnd][np.int32(idx)]
        res[idx] = sum_

def matvec_item_custom(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 256
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block

    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel[blocks_per_grid, threads_per_block](d_A1, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res

def rmatvec_item_custom(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    x = x.ravel()
    threads_per_block = 256
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block

    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel[blocks_per_grid, threads_per_block](d_A1, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res


In [11]:
linop_m = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom,
    rmatvec=rmatvec_item_custom
)
print("Operator created", flush=True)

print("Performing SVD...", flush=True)
v, s, u = svds(linop_m, k = 10, return_singular_vectors=True)
s = np.sort(s)[::-1]

matvec
Operator created
Performing SVD...
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec
matvec


In [12]:
s

array([0.00514201, 0.00307254, 0.00307254, 0.00183595, 0.00151782,
       0.00151782, 0.00116446, 0.00116446, 0.00096295, 0.00096295],
      dtype=float32)

In [13]:
C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)
np.linalg.norm(fisher_matrix-np.kron(B1, C1))/np.linalg.norm(fisher_matrix)


0.76467013