In [1]:
import numpy as np
import pandas as pd
import torch
from scipy.linalg import kron
from scipy.sparse.linalg import svds

In [2]:
from numba import jit

In [3]:
from numba import cuda

In [10]:
import pickle
with open("list_of_grads", "rb") as fp:   #Pickling
    list_of_grads = pickle.load(fp)

list_of_grads = torch.stack(list_of_grads, dim=0)

In [11]:
list_of_grads.shape

torch.Size([625, 192, 768])

In [12]:
list_of_grads[6]

tensor([[ 5.6783e-04,  7.5972e-05,  5.4008e-06,  ...,  2.3077e-04,
          0.0000e+00,  3.9983e-04],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 3.2776e-04,  5.8912e-05,  1.5232e-05,  ...,  1.0055e-04,
          0.0000e+00,  1.1171e-04],
        ...,
        [-1.3369e-04, -2.2792e-05,  1.4925e-04,  ...,  1.3892e-04,
          0.0000e+00,  1.1369e-04],
        [ 3.4120e-04,  1.5226e-04,  0.0000e+00,  ...,  1.5480e-04,
          0.0000e+00,  9.4895e-05],
        [-2.7674e-04,  0.0000e+00,  0.0000e+00,  ..., -6.0975e-05,
          0.0000e+00, -1.7622e-04]])

In [6]:
m = list_of_grads[0].shape[0]
n = list_of_grads[0].shape[1]


In [66]:
import numba

print (m, n)

from scipy.sparse.linalg import svds, LinearOperator


grad_vectors = torch.stack([grad.T.reshape(-1) for grad in list_of_grads])
print (grad_vectors.shape)
d_grad_vectors = cuda.to_device(grad_vectors)



@cuda.jit(device=True)
def getitem_fmatrices_permuted(d_grad_vectors, p, q):
    reduce_init_val = 0.0
    for idx in range(d_grad_vectors.shape[0]):
        #reduce_init_val+= d_grad_vectors[idx][p]*d_grad_vectors[idx][q]
        elem = d_grad_vectors[idx]
        reduce_init_val += elem[p] * elem[q]

    return reduce_init_val/d_grad_vectors.shape[0]


@cuda.jit
def matvec_kernel1(d_grad_vectors, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*getitem_fmatrices_permuted(d_grad_vectors, (pd)*n +qd, (pn)*n + qn)#d_grad_vector[(pd)*n +qd]*d_grad_vector[(pn)*n + qn] 
        res[idx] = sum_

@cuda.jit
def rmatvec_kernel1(d_grad_vectors, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx,elem_n = cuda.grid(2)
    elem_n = np.int32(elem_n)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        pd, pn = divmod(np.int32(idx), n)
        for jnd in range(x.size):
            qd, qn = divmod(jnd, m)
            #jnd, idx
            sum_ += x[jnd]*d_grad_vectors[elem_n][(jnd//m)*n +pd]*d_grad_vectors[elem_n][(jnd%m)*n + pn] 
            #sum_ += x[jnd]*getitem_fmatrices_permuted(d_grad_vectors, (jnd//m)*n +pd, (jnd%m)*n + pn)# d_grad_vector[(jnd//m)*n +pd]*d_grad_vector[(jnd%m)*n + pn]  #grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n]
        res[idx] += sum_

def matvec_item_custom1(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 1024
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block
    print (m**2, threads_per_block, blocks_per_grid)


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel1[blocks_per_grid, threads_per_block](d_grad_vectors, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    print (res)
    return res

def rmatvec_item_custom1(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    print ("rmatvec")
    x = x.ravel()
    threads_per_block = (1024, 1)
    blocks_per_grid = ((n**2 + threads_per_block[0] - 1) // threads_per_block[0], d_grad_vectors.shape[0])


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel1[blocks_per_grid, threads_per_block](d_grad_vectors, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return (res/d_grad_vectors.shape[0])


linop_m = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom1,
    rmatvec=rmatvec_item_custom1
)
print("Operator created", flush=True)

# Compute SVD
print("Performing SVD...", flush=True)
v, s, u = svds(linop_v, k = 1, return_singular_vectors=True)
#s = np.sort(s)[::-1]

192 256
torch.Size([157, 49152])
matvec
36864 1024 36




[0. 0. 0. ... 0. 0. 0.]
Operator created
Performing SVD...
rmatvec
matvec
36864 1024 36
[ 4.2055047e-11  1.0018565e-10 -2.4735373e-11 ...  2.5676721e-13
 -5.7788943e-12  1.0330430e-11]
rmatvec
matvec
36864 1024 36
[ 1.7374263e-11  2.1050408e-11 -6.1603353e-12 ...  6.0821926e-14
 -9.7395800e-13  4.5132244e-12]
rmatvec
matvec
36864 1024 36
[ 2.8029376e-11  4.1671683e-12 -4.2693944e-13 ...  2.1808889e-14
  9.0210131e-13  2.2330094e-12]
rmatvec
matvec
36864 1024 36
[ 7.5698926e-11 -1.4508956e-12  1.2979434e-11 ...  3.1836604e-14
  1.1252966e-12 -4.1505163e-14]
rmatvec
matvec
36864 1024 36
[ 6.27822377e-11 -1.29258887e-12  1.02872667e-11 ...  1.12067795e-14
 -1.32225979e-12  8.46865113e-13]
rmatvec
matvec
36864 1024 36
[ 2.0731049e-11 -6.4113944e-12 -2.6548590e-11 ... -1.5894057e-14
 -1.8452607e-12  2.7267114e-12]
rmatvec
matvec
36864 1024 36
[ 1.9652122e-11 -2.0171139e-11 -4.8472684e-11 ... -2.3796696e-14
 -2.1295311e-13  4.4745136e-12]
rmatvec
matvec
36864 1024 36
[-9.2977423e-12 -7.61849

In [8]:
#np.li

In [69]:
import numpy as np

def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)

is_pos_def(C1)

False

In [68]:
is_pos_def(B1)

False

In [None]:
d_grad_vectors.shape[0]

In [None]:
import numba

print (m, n)

from scipy.sparse.linalg import svds, LinearOperator


grad_vectors = torch.stack([grad.T.reshape(-1) for grad in list_of_grads])
print (grad_vectors.shape)
d_grad_vectors = cuda.to_device(grad_vectors)

reduce_init_val = 0.0


@cuda.jit(device=True)
def getitem_fmatrices_permuted(d_grad_vectors, p, q):
    reduce_init_val = 0.0
    for idx in range(d_grad_vectors.shape[0]):
        #reduce_init_val+= d_grad_vectors[idx][p]*d_grad_vectors[idx][q]
        elem = d_grad_vectors[idx]
        reduce_init_val += elem[p] * elem[q]

    return reduce_init_val/d_grad_vectors.shape[0]

@cuda.jit
def matvec_kernel(d_grad_vectors, x, res, m, n):
    """
    CUDA kernel to compute matvec operation (A * x).
    """
    idx = cuda.grid(1)
    
    pd, pn = divmod(np.int32(idx), m)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        for jnd in range(x.size):
            qd, qn = divmod(jnd, n)
            sum_ += x[jnd]*getitem_fmatrices_permuted(d_grad_vectors, (pd)*n +qd, (pn)*n + qn)#d_grad_vector[(pd)*n +qd]*d_grad_vector[(pn)*n + qn] 
        res[idx] = sum_
        
@cuda.jit
def rmatvec_kernel(d_grad_vectors, x, res, m, n):
    """
    CUDA kernel to compute rmatvec operation (A.T * x).
    """
    idx = cuda.grid(1)
    if idx < res.size:  # Each thread computes one element of the result
        sum_ = 0.0
        pd, pn = divmod(np.int32(idx), n)
        for jnd in range(x.size):
            #qd, qn = divmod(jnd, n)
            #jnd, idx
            sum_ += x[jnd]*getitem_fmatrices_permuted(d_grad_vectors, (jnd//m)*n +pd, (jnd%m)*n + pn)# d_grad_vector[(jnd//m)*n +pd]*d_grad_vector[(jnd%m)*n + pn]  #grad_vector[(q // m)*n +p // n]*grad_vector[(q % m)*n + p % n]
        res[idx] = sum_

def matvec_item_custom(x):
    """
    Host function to launch matvec_kernel.
    """
    print ("matvec")
    x = x.ravel()
    res = np.zeros((m**2), dtype=np.float32)
    threads_per_block = 1024
    blocks_per_grid = (m**2 + threads_per_block - 1) // threads_per_block
    print (m**2, threads_per_block, blocks_per_grid)


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((m**2), dtype=np.float32)

    matvec_kernel[blocks_per_grid, threads_per_block](d_grad_vectors, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    print (res)
    return res

def rmatvec_item_custom(x):
    """
    Host function to launch rmatvec_kernel.
    """
    res = np.zeros((n**2), dtype=np.float32)
    print ("rmatvec")
    x = x.ravel()
    threads_per_block = 1024
    blocks_per_grid = (n**2 + threads_per_block - 1) // threads_per_block


    #d_grad_vector = cuda.to_device(grad_vector)
    d_x = cuda.to_device(x)
    d_res = cuda.device_array((n**2), dtype=np.float32)

    rmatvec_kernel[blocks_per_grid, threads_per_block](d_grad_vectors, d_x, d_res, m, n)

    # Copy result back to host
    res = d_res.copy_to_host()
    return res


linop_v = LinearOperator(
    shape=(m**2, n**2),
    matvec=matvec_item_custom,
    rmatvec=rmatvec_item_custom
)
print("Operator created", flush=True)

# Compute SVD
print("Performing SVD...", flush=True)
v, s, u = svds(linop_v, k = 3, return_singular_vectors=True)
#s = np.sort(s)[::-1]

In [33]:
import numpy as np

def is_pos_def(x):
    return np.all(np.linalg.eigvals(x) > 0)

C1 = s[0] * u[-1, :].reshape(n, n)
B1 = v[:, -1].reshape(m, m)

is_pos_def(C1)

True

In [34]:
left = torch.rand(n*n)
right = torch.rand(m*m)


In [56]:
resr_1 = linop_v.matvec(left)

matvec
36864 1024 36
[ 3.8428084e-06  9.4785935e-07  6.1557512e-06 ... -2.2213476e-08
  2.0142468e-07  7.8174253e-07]


In [58]:
resr_2 = linop_m.matvec(left)

matvec
36864 1024 36
[ 3.8428084e-06  9.4785935e-07  6.1557512e-06 ... -2.2213476e-08
  2.0142468e-07  7.8174253e-07]


In [None]:
resr_2 = linop_m.matvec(left)

In [60]:
np.allclose(resr_1, resr_2, atol=1e-08)

True

In [51]:
resr_11 = linop_v.rmatvec(right)

rmatvec


In [63]:
resr_22 =linop_m.rmatvec(right)

rmatvec


In [53]:
resr_11

array([ 3.4164863e-05, -2.5415093e-06,  5.4153938e-06, ...,
       -9.7085467e-06, -4.3056963e-07,  6.9029455e-05], dtype=float32)

In [64]:
resr_22

array([ 3.4168224e-05, -2.5402694e-06,  5.4232087e-06, ...,
       -9.7085458e-06, -4.3056929e-07,  6.9029462e-05], dtype=float32)

In [65]:
np.allclose(resr_11, resr_22, atol=1e-06)

False